diff --git a/datacube/virtual/__init__.py b/datacube/virtual/__init__.py index 7f0f729d8a..cedd8355af 100644 --- a/datacube/virtual/__init__.py +++ b/datacube/virtual/__init__.py @@ -2,6 +2,7 @@ from .impl import VirtualProduct, Transformation, VirtualProductException from .transformations import MakeMask, ApplyMask, ToFloat, Rename, Select +from .transformations import Mean, year, month, week, day from .utils import reject_keys from datacube.model import Measurement @@ -14,7 +15,7 @@ class NameResolver: """ Apply a mapping from name to callable objects in a recipe. """ - def __init__(self, **lookup_table): + def __init__(self, lookup_table): self.lookup_table = lookup_table def construct(self, **recipe) -> VirtualProduct: @@ -22,53 +23,44 @@ def construct(self, **recipe) -> VirtualProduct: get = recipe.get - kind_keys = {key for key in recipe if key in ['product', 'transform', 'collate', 'juxtapose']} - if len(kind_keys) < 1: - raise VirtualProductException("virtual product kind not specified in {}".format(recipe)) - elif len(kind_keys) > 1: - raise VirtualProductException("ambiguous kind in {}".format(recipe)) - - if 'product' in recipe: - def resolve_func(key, value): - if key not in ['fuse_func', 'dataset_predicate']: - return value - - if callable(value): - return value + def lookup(name, namespace=None, kind='transformation'): + if callable(name): + return name + if namespace is not None and namespace in self.lookup_table and name in self.lookup_table[namespace]: + result = self.lookup_table[namespace][name] + else: try: - return import_function(value) + result = import_function(name) except (ImportError, AttributeError): - raise VirtualProductException("could not resolve function {} in {}".format(key, recipe)) + msg = "could not resolve {} {} in {}".format(kind, name, recipe) + raise VirtualProductException(msg) - return VirtualProduct({key: resolve_func(key, value) for key, value in recipe.items()}) + if not callable(result): + raise VirtualProductException("{} not callable in {}".format(kind, recipe)) - if 'transform' in recipe: - def resolve_transform(cls_name): - if callable(cls_name): - return cls_name + return result - if cls_name in self.lookup_table: - cls = self.lookup_table[cls_name] - else: - try: - cls = import_function(cls_name) - except (ImportError, AttributeError): - msg = "could not resolve transformation {} in {}".format(cls_name, recipe) - raise VirtualProductException(msg) - - if not callable(cls): - raise VirtualProductException("transformation not callable in {}".format(recipe)) + kind_keys = {key for key in recipe if key in ['product', 'transform', 'collate', 'juxtapose', 'aggregate']} + if len(kind_keys) < 1: + raise VirtualProductException("virtual product kind not specified in {}".format(recipe)) + elif len(kind_keys) > 1: + raise VirtualProductException("ambiguous kind in {}".format(recipe)) - return cls + if 'product' in recipe: + func_keys = ['fuse_func', 'dataset_predicate'] + return VirtualProduct({key: value if key not in func_keys else lookup(value, kind='function') + for key, value in recipe.items()}) + if 'transform' in recipe: cls_name = recipe['transform'] input_product = get('input') if input_product is None: raise VirtualProductException("no input for transformation in {}".format(recipe)) - return VirtualProduct(dict(transform=resolve_transform(cls_name), input=self.construct(**input_product), + return VirtualProduct(dict(transform=lookup(cls_name, 'transform'), + input=self.construct(**input_product), **reject_keys(recipe, ['transform', 'input']))) if 'collate' in recipe: @@ -85,14 +77,34 @@ def resolve_transform(cls_name): return VirtualProduct(dict(juxtapose=[self.construct(**child) for child in recipe['juxtapose']], **reject_keys(recipe, ['juxtapose']))) + if 'aggregate' in recipe: + cls_name = recipe['aggregate'] + input_product = get('input') + group_by = get('group_by') + + if input_product is None: + raise VirtualProductException("no input for aggregate in {}".format(recipe)) + if group_by is None: + raise VirtualProductException("no group_by for aggregate in {}".format(recipe)) + + return VirtualProduct(dict(aggregate=lookup(cls_name, 'aggregate'), + group_by=lookup(group_by, 'aggregate/group_by', kind='group_by'), + input=self.construct(**input_product), + **reject_keys(recipe, ['aggregate', 'input', 'group_by']))) + raise VirtualProductException("could not understand virtual product recipe: {}".format(recipe)) -DEFAULT_RESOLVER = NameResolver(make_mask=MakeMask, - apply_mask=ApplyMask, - to_float=ToFloat, - rename=Rename, - select=Select) +DEFAULT_RESOLVER = NameResolver({'transform': dict(make_mask=MakeMask, + apply_mask=ApplyMask, + to_float=ToFloat, + rename=Rename, + select=Select), + 'aggregate': dict(mean=Mean), + 'aggregate/group_by': dict(year=year, + month=month, + week=week, + day=day)}) def construct(**recipe: Mapping[str, Any]) -> VirtualProduct: diff --git a/datacube/virtual/impl.py b/datacube/virtual/impl.py index 00927b10a2..14cd2cced9 100644 --- a/datacube/virtual/impl.py +++ b/datacube/virtual/impl.py @@ -1,7 +1,5 @@ -# TODO: needs an aggregation phase (use xarray.DataArray.groupby?) # TODO: measurement dependency tracking # TODO: a mechanism to set settings for the leaf notes -# TODO: lineage tracking per observation # TODO: integrate GridWorkflow functionality (spatial binning) """ @@ -11,7 +9,7 @@ """ from abc import ABC, abstractmethod -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from functools import reduce from typing import Any, Dict, List, cast @@ -24,7 +22,7 @@ from datacube.api.grid_workflow import _fast_slice from datacube.api.query import Query, query_group_by, query_geopolygon from datacube.model import Measurement, DatasetType -from datacube.model.utils import xr_apply +from datacube.model.utils import xr_apply, xr_iter from .utils import qualified_name, merge_dicts from .utils import select_unique, select_keys, reject_keys, merge_search_terms @@ -70,6 +68,7 @@ def shape(self): return self.pile.shape + self.geobox.shape def __getitem__(self, chunk): + # TODO: test this functionality pile = self.pile return VirtualDatasetBox(_fast_slice(pile, chunk[:len(pile.shape)]), @@ -93,6 +92,33 @@ def split(self, dim='time'): for i in range(length): yield VirtualDatasetBox(pile.isel(**{dim: slice(i, i + 1)}), self.geobox, self.product_definitions) + def input_datasets(self): + def traverse(entry): + if isinstance(entry, Mapping): + if 'collate' in entry: + _, child = entry['collate'] + yield from traverse(child) + elif 'juxtapose' in entry: + for child in entry['juxtapose']: + yield from traverse(child) + else: + raise VirtualProductException("malformed box") + + elif isinstance(entry, Sequence): + yield from entry + + elif isinstance(entry, VirtualDatasetBox): + for _, _, child in xr_iter(entry.input_datasets()): + yield from child + + else: + raise VirtualProductException("malformed box") + + def worker(index, entry): + return set(traverse(entry)) + + return self.map(worker).pile + class Transformation(ABC): """ @@ -134,6 +160,7 @@ class VirtualProduct(Mapping): - transform: on-the-fly computation on data being loaded - collate: stack observations from products with the same set of measurements - juxtapose: put measurements from different products side-by-side + - aggregate: take (non-spatial) statistics of grouped data """ _GEOBOX_KEYS = ['output_crs', 'resolution', 'align'] @@ -180,6 +207,21 @@ def _transformation(self) -> Transformation: return cast(Transformation, obj) + @property + def _statistic(self) -> Transformation: + """ The `Transformation` object associated with an aggregate product. """ + cls = self['aggregate'] + + try: + obj = cls(**{key: value for key, value in self.items() + if key not in ['aggregate', 'input', 'group_by']}) + except TypeError: + raise VirtualProductException("transformation {} could not be instantiated".format(cls)) + + self._assert(isinstance(obj, Transformation), "not a transformation object: {}".format(obj)) + + return cast(Transformation, obj) + @property def _input(self) -> 'VirtualProduct': """ The input product of a transform product. """ @@ -204,8 +246,9 @@ def _product(self): @property def _kind(self): - """ One of product, transform, collate, or juxtapose. """ - candidates = [key for key in list(self) if key in ['product', 'transform', 'collate', 'juxtapose']] + """ One of product, transform, collate, juxtapose, or aggregate. """ + candidates = [key for key in list(self) + if key in ['product', 'transform', 'collate', 'juxtapose', 'aggregate']] self._assert(len(candidates) == 1, "ambiguous kind") return candidates[0] @@ -216,9 +259,8 @@ def output_measurements(self, product_definitions: Dict[str, DatasetType]) -> Di A dictionary mapping names to measurement metadata. :param product_definitions: a dictionary mapping product names to products (`DatasetType` objects) """ - get = self.get - if 'product' in self: + def _product_measurements_(): self._assert(self._product in product_definitions, "product {} not found in definitions".format(self._product)) @@ -226,7 +268,7 @@ def output_measurements(self, product_definitions: Dict[str, DatasetType]) -> Di measurements = {measurement['name']: Measurement(**measurement) for measurement in product.definition['measurements']} - if get('measurements') is None: + if self.get('measurements') is None: return measurements try: @@ -235,12 +277,12 @@ def output_measurements(self, product_definitions: Dict[str, DatasetType]) -> Di except KeyError as ke: raise VirtualProductException("could not find measurement: {}".format(ke.args)) - elif 'transform' in self: + def _transform_measurements_(): input_measurements = self._input.output_measurements(product_definitions) return self._transformation.measurements(input_measurements) - elif 'collate' in self: + def _collate_measurements_(): input_measurement_list = [child.output_measurements(product_definitions) for child in self._children] @@ -250,7 +292,7 @@ def output_measurements(self, product_definitions: Dict[str, DatasetType]) -> Di self._assert(set(child) == set(first), "child datasets do not all have the same set of measurements") - name = get('index_measurement_name') + name = self.get('index_measurement_name') if name is None: return first @@ -259,7 +301,7 @@ def output_measurements(self, product_definitions: Dict[str, DatasetType]) -> Di first.update({name: Measurement(name=name, dtype='int8', nodata=-1, units='1')}) return first - elif 'juxtapose' in self: + def _juxtapose_measurements_(): input_measurement_list = [child.output_measurements(product_definitions) for child in self._children] @@ -272,14 +314,28 @@ def output_measurements(self, product_definitions: Dict[str, DatasetType]) -> Di return result + def _aggregate_measurements_(): + input_measurements = self._input.output_measurements(product_definitions) + + return self._statistic.measurements(input_measurements) + + if 'product' in self: + return _product_measurements_() + elif 'transform' in self: + return _transform_measurements_() + elif 'collate' in self: + return _collate_measurements_() + elif 'juxtapose' in self: + return _juxtapose_measurements_() + elif 'aggregate' in self: + return _aggregate_measurements_() else: raise VirtualProductException("virtual product was not validated") def query(self, dc: Datacube, **search_terms: Dict[str, Any]) -> VirtualDatasetBag: """ Collection of datasets that match the query. """ - get = self.get - if 'product' in self: + def _product_query_(): product = dc.index.products.get_by_name(self._product) if product is None: raise VirtualProductException("could not find product {}".format(self._product)) @@ -297,7 +353,7 @@ def query(self, dc: Datacube, **search_terms: Dict[str, Any]) -> VirtualDatasetB datasets = select_datasets_inside_polygon(datasets, query.geopolygon) # should we put it in the Transformation class? - if get('dataset_predicate') is not None: + if self.get('dataset_predicate') is not None: datasets = [dataset for dataset in datasets if self['dataset_predicate'](dataset)] @@ -305,10 +361,10 @@ def query(self, dc: Datacube, **search_terms: Dict[str, Any]) -> VirtualDatasetB return VirtualDatasetBag(list(datasets), product.grid_spec, query.geopolygon, {product.name: product}) - elif 'transform' in self: + def _transform_or_aggregate_query_(): return self._input.query(dc, **search_terms) - elif 'collate' in self or 'juxtapose' in self: + def _collate_or_juxtapose_query_(): result = [child.query(dc, **search_terms) for child in self._children] @@ -316,7 +372,12 @@ def query(self, dc: Datacube, **search_terms: Dict[str, Any]) -> VirtualDatasetB select_unique([datasets.grid_spec for datasets in result]), select_unique([datasets.geopolygon for datasets in result]), merge_dicts([datasets.product_definitions for datasets in result])) - + if 'product' in self: + return _product_query_() + elif 'transform' in self or 'aggregate' in self: + return _transform_or_aggregate_query_() + elif 'collate' in self or 'juxtapose' in self: + return _collate_or_juxtapose_query_() else: raise VirtualProductException("virtual product was not validated") @@ -328,23 +389,22 @@ def group(self, datasets: VirtualDatasetBag, **search_terms: Dict[str, Any]) -> :param datasets: the `VirtualDatasetBag` to fetch data from :param query: to specify a spatial sub-region """ - grid_spec = datasets.grid_spec - geopolygon = datasets.geopolygon - - if 'product' in self: + def _product_group_(): # select only those inside the ROI # ROI could be smaller than the query for the `query` method + if query_geopolygon(**search_terms) is not None: geopolygon = query_geopolygon(**search_terms) selected = list(select_datasets_inside_polygon(datasets.pile, geopolygon)) else: + geopolygon = datasets.geopolygon selected = list(datasets.pile) # geobox merged = merge_search_terms(select_keys(self, self._NON_SPATIAL_KEYS), select_keys(search_terms, self._NON_SPATIAL_KEYS)) - geobox = output_geobox(datasets=selected, grid_spec=grid_spec, + geobox = output_geobox(datasets=selected, grid_spec=datasets.grid_spec, geopolygon=geopolygon, **select_keys(merged, self._GEOBOX_KEYS)) # group by time @@ -355,10 +415,10 @@ def group(self, datasets: VirtualDatasetBag, **search_terms: Dict[str, Any]) -> geobox, datasets.product_definitions) - elif 'transform' in self: + def _transform_group_(): return self._input.group(datasets, **search_terms) - elif 'collate' in self: + def _collate_group_(): self._assert('collate' in datasets.pile and len(datasets.pile['collate']) == len(self._children), "invalid dataset pile") @@ -376,11 +436,11 @@ def tag(_, value): for source_index, (product, dataset_pile) in enumerate(zip(self._children, datasets.pile['collate']))] - return VirtualDatasetBox(xarray.concat([grouped.pile for grouped in groups], dim='time'), + return VirtualDatasetBox(xarray.concat([grouped.pile for grouped in groups], dim=self.get('dim', 'time')), select_unique([grouped.geobox for grouped in groups]), merge_dicts([grouped.product_definitions for grouped in groups])) - elif 'juxtapose' in self: + def _juxtapose_group_(): self._assert('juxtapose' in datasets.pile and len(datasets.pile['juxtapose']) == len(self._children), "invalid dataset pile") @@ -398,6 +458,29 @@ def tuplify(indexes, _): select_unique([grouped.geobox for grouped in groups]), merge_dicts([grouped.product_definitions for grouped in groups])) + def _aggregate_group_(): + grouped = self._input.group(datasets, **search_terms) + dim = self.get('dim', 'time') + + def to_box(value): + return xarray.DataArray([VirtualDatasetBox(value, grouped.geobox, grouped.product_definitions)], + dims=['_fake_']) + + result = grouped.pile.groupby(self['group_by'](grouped.pile[dim])).apply(to_box).squeeze('_fake_') + result[dim].attrs.update(grouped.pile[dim].attrs) + + return VirtualDatasetBox(result, grouped.geobox, grouped.product_definitions) + + if 'product' in self: + return _product_group_() + elif 'transform' in self: + return _transform_group_() + elif 'collate' in self: + return _collate_group_() + elif 'juxtapose' in self: + return _juxtapose_group_() + elif 'aggregate' in self: + return _aggregate_group_() else: raise VirtualProductException("virtual product was not validated") @@ -409,7 +492,7 @@ def fetch(self, grouped: VirtualDatasetBox, **load_settings: Dict[str, Any]) -> product_definitions = grouped.product_definitions _ = self.output_measurements(product_definitions) - if 'product' in self: + def _product_fetch_(): merged = merge_search_terms(select_keys(self, self._LOAD_KEYS), select_keys(load_settings, self._LOAD_KEYS)) @@ -423,10 +506,10 @@ def fetch(self, grouped: VirtualDatasetBox, **load_settings: Dict[str, Any]) -> return apply_aliases(result, product_definitions[self._product], list(measurements)) - elif 'transform' in self: + def _transform_fetch_(): return self._transformation.compute(self._input.fetch(grouped, **load_settings)) - elif 'collate' in self: + def _collate_fetch_(): def is_from(source_index): def result(_, value): self._assert('collate' in value, "malformed dataset pile in collate") @@ -465,9 +548,11 @@ def fetch_child(child, source_index, r): non_empty = [g for g in groups if g is not None] - return xarray.concat(non_empty, dim='time').assign_attrs(**select_unique([g.attrs for g in non_empty])) + return xarray.concat(non_empty, + dim=self.get('dim', 'time')).assign_attrs(**select_unique([g.attrs + for g in non_empty])) - elif 'juxtapose' in self: + def _juxtapose_fetch_(): def select_child(source_index): def result(_, value): self._assert('juxtapose' in value, "malformed dataset pile in juxtapose") @@ -484,6 +569,34 @@ def fetch_recipe(source_index): return xarray.merge(groups).assign_attrs(**select_unique([g.attrs for g in groups])) + def _aggregate_fetch_(): + dim = self.get('dim', 'time') + + def xr_map(array, func): + # convenient function close to `xr_apply` in spirit + coords = {key: value.values for key, value in array.coords.items()} + for i in numpy.ndindex(array.shape): + yield func({key: value[i] for key, value in coords.items()}, array.values[i]) + + def statistic(coords, value): + data = self._input.fetch(value, **load_settings) + result = self._statistic.compute(data) + result.coords[dim] = coords[dim] + return result + + groups = list(xr_map(grouped.pile, statistic)) + return xarray.concat(groups, dim=dim).assign_attrs(**select_unique([g.attrs for g in groups])) + + if 'product' in self: + return _product_fetch_() + elif 'transform' in self: + return _transform_fetch_() + elif 'collate' in self: + return _collate_fetch_() + elif 'juxtapose' in self: + return _juxtapose_fetch_() + elif 'aggregate' in self: + return _aggregate_fetch_() else: raise VirtualProductException("virtual product was not validated") @@ -508,6 +621,12 @@ def reconstruct(product): children = [reconstruct(child) for child in product['juxtapose']] return dict(juxtapose=children, **reject_keys(product, ['juxtapose'])) + if 'aggregate' in product: + input_product = reconstruct(product['input']) + return dict(aggregate=qualified_name(product['aggregate']), + group_by=qualified_name(product['group_by']), + input=input_product, **reject_keys(product, ['input', 'aggregate', 'group_by'])) + else: raise VirtualProductException("virtual product was not validated") @@ -518,4 +637,4 @@ def load(self, dc: Datacube, **query: Dict[str, Any]) -> xarray.Dataset: """ Mimic `datacube.Datacube.load`. For illustrative purposes. May be removed in the future. """ datasets = self.query(dc, **query) grouped = self.group(datasets, **query) - return self.fetch(grouped, **query).sortby('time') + return self.fetch(grouped, **query) diff --git a/datacube/virtual/transformations.py b/datacube/virtual/transformations.py index 740421d0c0..0a8e112b6a 100644 --- a/datacube/virtual/transformations.py +++ b/datacube/virtual/transformations.py @@ -156,3 +156,36 @@ def compute(self, data): return data.drop([measurement for measurement in data.data_vars if measurement not in self.measurement_names]) + + +def year(time): + return time.astype('datetime64[Y]') + + +def month(time): + return time.astype('datetime64[M]') + + +def week(time): + return time.astype('datetime64[W]') + + +def day(time): + return time.astype('datetime64[D]') + + +# TODO: all time stats + +class Mean(Transformation): + """ + Take the mean of the measurements. + """ + + def __init__(self, dim='time'): + self.dim = dim + + def measurements(self, input_measurements): + return input_measurements + + def compute(self, data): + return data.mean(dim=self.dim) diff --git a/tests/api/test_virtual.py b/tests/api/test_virtual.py index f8d321a66f..948e803020 100644 --- a/tests/api/test_virtual.py +++ b/tests/api/test_virtual.py @@ -304,3 +304,29 @@ def test_aliases(dc, query): assert 'verde' in data assert 'green' not in data + + +def test_aggregate(dc, query): + aggr = construct_from_yaml(""" + aggregate: mean + group_by: month + input: + transform: to_float + input: + collate: + - product: ls7_nbar_albers + measurements: [blue] + - product: ls8_nbar_albers + measurements: [blue] + """) + + measurements = aggr.output_measurements({product.name: product + for product in dc.index.products.get_all()}) + assert 'blue' in measurements + + with mock.patch('datacube.virtual.impl.Datacube') as mock_datacube: + mock_datacube.load_data = load_data + mock_datacube.group_datasets = group_datasets + data = aggr.load(dc, **query) + + assert data.time.shape == (2,)