From 4c456844326de64e7426521cd810b308c89885c6 Mon Sep 17 00:00:00 2001
From: Damien Ayers
Date: Thu, 5 Jan 2017 11:57:24 +1100
Subject: [PATCH] Allow ignoring exceptions while reading data

---
 datacube/api/core.py          | 10 ++++---
 datacube/api/grid_workflow.py |  4 +--
 datacube/storage/storage.py   | 56 +++++++++++++++++++++++++++--------
 tests/storage/test_storage.py | 47 ++++++++++++++++++++++++++++-
 4 files changed, 97 insertions(+), 20 deletions(-)

diff --git a/datacube/api/core.py b/datacube/api/core.py
index 689476d5d9..d2cffbe73e 100644
--- a/datacube/api/core.py
+++ b/datacube/api/core.py
@@ -473,7 +473,7 @@ def product_data(*args, **kwargs):
         return Datacube.load_data(*args, **kwargs)
 
     @staticmethod
-    def load_data(sources, geobox, measurements, fuse_func=None, dask_chunks=None):
+    def load_data(sources, geobox, measurements, fuse_func=None, dask_chunks=None, ignore_errors=False):
         """
         Load data from :meth:`group_datasets` into an :class:`xarray.Dataset`.
 
@@ -504,7 +504,8 @@ def load_data(sources, geobox, measurements, fuse_func=None, dask_chunks=None):
             def data_func(measurement):
                 data = numpy.full(sources.shape + geobox.shape, measurement['nodata'], dtype=measurement['dtype'])
                 for index, datasets in numpy.ndenumerate(sources.values):
-                    _fuse_measurement(data[index], datasets, geobox, measurement, fuse_func)
+                    _fuse_measurement(data[index], datasets, geobox, measurement, fuse_func=fuse_func,
+                                      ignore_errors=ignore_errors)
                 return data
         else:
             def data_func(measurement):
@@ -561,14 +562,15 @@ def fuse_lazy(datasets, geobox, measurement, fuse_func=None, prepend_dims=0):
     return data.reshape(prepend_shape + geobox.shape)
 
 
-def _fuse_measurement(dest, datasets, geobox, measurement, fuse_func=None):
+def _fuse_measurement(dest, datasets, geobox, measurement, ignore_errors=False, fuse_func=None):
     reproject_and_fuse([DatasetSource(dataset, measurement['name']) for dataset in datasets],
                        dest,
                        geobox.affine,
                        geobox.crs,
                        dest.dtype.type(measurement['nodata']),
                        resampling=measurement.get('resampling_method', 'nearest'),
-                       fuse_func=fuse_func)
+                       fuse_func=fuse_func,
+                       ignore_errors=ignore_errors)
 
 
 def get_bounds(datasets, crs):
diff --git a/datacube/api/grid_workflow.py b/datacube/api/grid_workflow.py
index 21daba7345..a8003e9bcf 100644
--- a/datacube/api/grid_workflow.py
+++ b/datacube/api/grid_workflow.py
@@ -304,7 +304,7 @@ def list_tiles(self, cell_index=None, **query):
         return self.tile_sources(observations, query_group_by(**query))
 
     @staticmethod
-    def load(tile, measurements=None, dask_chunks=None, fuse_func=None, resampling=None):
+    def load(tile, measurements=None, dask_chunks=None, fuse_func=None, resampling=None, ignore_errors=False):
         """
         Load data for a cell/tile.
@@ -348,7 +348,7 @@ def load(tile, measurements=None, dask_chunks=None, fuse_func=None, resampling=N
             measurements = set_resampling_method(measurements, resampling)
 
         dataset = Datacube.load_data(tile.sources, tile.geobox, measurements.values(), dask_chunks=dask_chunks,
-                                     fuse_func=fuse_func)
+                                     fuse_func=fuse_func, ignore_errors=ignore_errors)
 
         return dataset
 
diff --git a/datacube/storage/storage.py b/datacube/storage/storage.py
index 95f130f429..c71672f3bc 100644
--- a/datacube/storage/storage.py
+++ b/datacube/storage/storage.py
@@ -96,16 +96,19 @@ def _no_fractional_translate(affine, eps=0.01):
     return abs(affine.c % 1.0) < eps and abs(affine.f % 1.0) < eps
 
 
-def reproject(source, dest, dst_transform, dst_nodata, dst_projection, resampling):
+def read_from_source(source, dest, dst_transform, dst_nodata, dst_projection, resampling):
     """
     Read from `source` into `dest`, reprojecting if necessary.
+
+    :param BaseRasterDataSource source: Data source
+    :param numpy.ndarray dest: Data destination
     """
     with source.open() as src:
         array_transform = ~src.transform * dst_transform
         if (src.crs == dst_projection and _no_scale(array_transform) and
                 (resampling == Resampling.nearest or _no_fractional_translate(array_transform))):
-            dydx = (int(round(array_transform.f)), int(round(array_transform.c)))
-            read, write, shape = zip(*map(_calc_offsets, dydx, src.shape, dest.shape))
+            dy_dx = int(round(array_transform.f)), int(round(array_transform.c))
+            read, write, shape = zip(*map(_calc_offsets, dy_dx, src.shape, dest.shape))
 
             dest.fill(dst_nodata)
             if all(shape):
@@ -126,10 +129,28 @@ def reproject(source, dest, dst_transform, dst_nodata, dst_projection, resamplin
                       NUM_THREADS=OPTIONS['reproject_threads'])
 
 
+@contextmanager
+def ignore_if(ignore_errors):
+    """Ignore any OSError raised within this block if `ignore_errors` is True"""
+    if ignore_errors:
+        try:
+            yield
+        except OSError as e:
+            _LOG.warning('Ignoring Exception: %s', e)
+    else:
+        yield
+
+
 def reproject_and_fuse(sources, destination, dst_transform, dst_projection, dst_nodata,
-                       resampling='nearest', fuse_func=None):
+                       resampling='nearest', fuse_func=None, ignore_errors=False):
     """
     Reproject and fuse `sources` into a 2D numpy array `destination`.
+
+    :param List[BaseRasterDataSource] sources: Data sources to open and read from
+    :param numpy.ndarray destination: ndarray of appropriate size to read data into
+    :type resampling: str
+    :type fuse_func: callable or None
+    :param bool ignore_errors: If True, log and ignore any OSError raised while reading, instead of raising it.
""" assert len(destination.shape) == 2 @@ -144,19 +165,20 @@ def copyto_fuser(dest, src): fuse_func = fuse_func or copyto_fuser + destination.fill(dst_nodata) if len(sources) == 0: - destination.fill(dst_nodata) return destination elif len(sources) == 1: - reproject(sources[0], destination, dst_transform, dst_nodata, dst_projection, resampling) + with ignore_if(ignore_errors): + read_from_source(sources[0], destination, dst_transform, dst_nodata, dst_projection, resampling) return destination else: - destination.fill(dst_nodata) - + # Muitiple sources, we need to fuse them together into a single array buffer_ = numpy.empty(destination.shape, dtype=destination.dtype) for source in sources: - reproject(source, buffer_, dst_transform, dst_nodata, dst_projection, resampling) - fuse_func(destination, buffer_) + with ignore_if(ignore_errors): + read_from_source(source, buffer_, dst_transform, dst_nodata, dst_projection, resampling) + fuse_func(destination, buffer_) return destination @@ -233,7 +255,7 @@ def reproject(self, dest, dst_transform, dst_crs, dst_nodata, resampling, **kwar class BaseRasterDataSource(object): """ - Interface used by fuse_sources and reproject + Interface used by fuse_sources and read_from_source """ def __init__(self, filename, nodata): self.filename = filename @@ -250,6 +272,7 @@ def get_crs(self): @contextmanager def open(self): + """Context manager which returns a `BandDataSource`""" try: _LOG.debug("opening %s", self.filename) with rasterio.open(self.filename) as src: @@ -281,9 +304,9 @@ def open(self): raise e -class BasicRasterDataSource(BaseRasterDataSource): +class RasterFileDataSource(BaseRasterDataSource): def __init__(self, filename, bandnumber, nodata=None, crs=None, transform=None): - super(BasicRasterDataSource, self).__init__(filename, nodata) + super(RasterFileDataSource, self).__init__(filename, nodata) self.bandnumber = bandnumber self.crs = crs self.transform = transform @@ -351,6 +374,7 @@ def _url2rasterio(url_str, fmt, layer): class DatasetSource(BaseRasterDataSource): + """Data source for reading from a Datacube Dataset""" def __init__(self, dataset, measurement_id): self._dataset = dataset self._measurement = dataset.measurements[measurement_id] @@ -399,6 +423,12 @@ def create_netcdf_storage_unit(filename, :param pathlib.Path filename: filename to write to :param datacube.model.CRS crs: Datacube CRS object defining the spatial projection + :param dict coordinates: Dict of named `datacube.model.Coordinate`s to create + :param dict variables: Dict of named `datacube.model.Variable`s to create + :param dict variable_params: + Dict of dicts, with keys matching variable names, of extra parameters for variables + :param dict global_attributes: named global attributes to add to output file + :param dict netcdfparams: Extra parameters to use when creating netcdf file :return: open netCDF4.Dataset object, ready for writing to """ filename = Path(filename) diff --git a/tests/storage/test_storage.py b/tests/storage/test_storage.py index 6940c5c17c..62204fbedc 100644 --- a/tests/storage/test_storage.py +++ b/tests/storage/test_storage.py @@ -5,6 +5,7 @@ from affine import Affine, identity import xarray import mock +import pytest from datacube.model import GeoBox, CRS from datacube.storage.storage import write_dataset_to_netcdf, reproject_and_fuse @@ -86,7 +87,8 @@ def test_mixed_result_when_first_source_partially_empty(): assert (output_data == [[1, 1], [2, 2]]).all() -def _mock_datasetsource(value, crs, shape): +def _mock_datasetsource(value, crs=None, 
shape=(2, 2)):
+    crs = crs or mock.MagicMock()
     dataset_source = mock.MagicMock()
     rio_reader = dataset_source.open.return_value.__enter__.return_value
     rio_reader.crs = crs
@@ -99,3 +101,46 @@ def _mock_datasetsource(value, crs, shape):
 #     dest[:] = value
 #     rio_reader.reproject.side_effect = fill_array
     return dataset_source
+
+
+def test_read_from_broken_source():
+    crs = mock.MagicMock()
+    shape = (2, 2)
+    no_data = -1
+
+    source1 = _mock_datasetsource([[1, 1], [no_data, no_data]], crs=crs, shape=shape)
+    source2 = _mock_datasetsource([[2, 2], [2, 2]], crs=crs, shape=shape)
+    sources = [source1, source2]
+
+    rio_reader = source1.open.return_value.__enter__.return_value
+    rio_reader.read.side_effect = OSError('Read or write failed')
+
+    output_data = numpy.full(shape, fill_value=no_data, dtype='int16')
+
+    # Check that the exception is raised by default
+    with pytest.raises(OSError):
+        reproject_and_fuse(sources, output_data, dst_transform=identity,
+                           dst_projection=crs, dst_nodata=no_data)
+
+    # Check that read errors can be ignored
+    reproject_and_fuse(sources, output_data, dst_transform=identity,
+                       dst_projection=crs, dst_nodata=no_data, ignore_errors=True)
+
+    assert (output_data == [[2, 2], [2, 2]]).all()
+
+
+def _create_broken_netcdf(tmpdir):
+    import os
+    import netCDF4
+    output_path = str(tmpdir / 'broken_netcdf_file.nc')
+    with netCDF4.Dataset(output_path, 'w') as nco:
+        nco.createDimension('x', 50)
+        nco.createDimension('y', 50)
+        nco.createVariable('blank', 'int16', ('y', 'x'))
+
+    with open(output_path, 'rb+') as filehandle:
+        filehandle.seek(-3, os.SEEK_END)
+        filehandle.truncate()
+
+    with netCDF4.Dataset(output_path) as nco:
+        blank = nco.variables['blank']
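
Usage note (illustrative, not part of the patch): the new flag is exposed
through `Datacube.load_data` and `GridWorkflow.load`. A minimal sketch,
assuming `tile` is a Tile obtained beforehand from `GridWorkflow.list_tiles()`:

    from datacube.api import GridWorkflow

    # With ignore_errors=True, a failed read (OSError) of any one dataset is
    # logged as a warning and skipped; its pixels keep the nodata fill value
    # instead of the whole load aborting.
    dataset = GridWorkflow.load(tile, ignore_errors=True)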
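
The behaviour is implemented by the `ignore_if` context manager added in
datacube/storage/storage.py, which swallows only OSError (logging a warning)
and lets every other exception type propagate:

    from datacube.storage.storage import ignore_if

    with ignore_if(True):
        raise OSError('simulated failed read')  # logged as a warning, then ignored

    with ignore_if(False):
        raise OSError('failed read')            # raised as usual

    with ignore_if(True):
        raise ValueError('not an OSError')      # still raised: only OSError is caught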