Skip to content

Commit

Permalink
valid_data_mask tests
Browse files Browse the repository at this point in the history
  • Loading branch information
uchchwhash committed Jul 15, 2019
1 parent 3f37d5f commit d1907c9
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 6 deletions.
15 changes: 13 additions & 2 deletions datacube/storage/_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)

from datacube.utils import ignore_exceptions_if
from datacube.utils.math import dtype_is_float
from datacube.utils.geometry import GeoBox, Coordinate, roi_is_empty
from datacube.model import Measurement
from datacube.drivers._types import ReaderDriver
Expand All @@ -26,13 +27,23 @@
ProgressFunction = Callable[[int, int], Any] # pylint: disable=invalid-name


def _default_fuser(dst: np.ndarray, src: np.ndarray, dst_nodata: float) -> None:
def _default_fuser(dst: np.ndarray, src: np.ndarray, dst_nodata) -> None:
""" Overwrite only those pixels in `dst` with `src` that are "not valid"
For every pixel in dst that equals to dst_nodata replace it with pixel
from src.
"""
where_nodata = (dst == dst_nodata) if not np.isnan(dst_nodata) else np.isnan(dst)
if dtype_is_float(dst.dtype):
if dst_nodata is None or np.isnan(dst_nodata):
where_nodata = np.isnan(dst)
else:
where_nodata = np.isnan(dst) | (dst == dst_nodata)
else:
if dst_nodata is None:
where_nodata = np.full_like(dst, False, dtype=np.bool)
else:
where_nodata = dst == dst_nodata

np.copyto(dst, src, where=where_nodata)


Expand Down
10 changes: 6 additions & 4 deletions datacube/storage/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,17 @@ def valid_data_mask(data):
if not isinstance(data, DataArray):
raise TypeError('valid_data_mask not supported for type {}'.format(type(data)))

nodata = data.attrs.get('nodata', None)

if dtype_is_float(data.dtype):
if 'nodata' not in data.attrs or numpy.isnan(data.attrs['nodata']):
if nodata is None or numpy.isnan(nodata):
return ~xarray.ufuncs.isnan(data)
return (data != data.nodata) & ~xarray.ufuncs.isnan(data)
return (data != nodata) & ~xarray.ufuncs.isnan(data)

# not float
if 'nodata' not in data.attrs:
if nodata is None:
return xarray.full_like(data, True, dtype=numpy.bool)
return data != data.nodata
return data != nodata


def mask_valid_data(data, keep_attrs=True):
Expand Down
35 changes: 35 additions & 0 deletions tests/api/test_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,38 @@ def test_valid_data_mask():

output_da = valid_data_mask(data_array)
assert output_da.equals(expected_data_array)

expected_data_array = DataArray(np.array([[True, True, True], [True, True, True], [True, True, True]],
dtype='bool'))
data_array = DataArray([[1, -999, -999], [2, 3, -999], [-999, -999, -999]])
dataset = Dataset(data_vars={'var_one': data_array})

output_ds = valid_data_mask(dataset)
assert output_ds.data_vars['var_one'].equals(expected_data_array)

output_da = valid_data_mask(data_array)
assert output_da.equals(expected_data_array)

expected_data_array = DataArray(np.array([[True, False, False], [True, True, False], [False, False, False]],
dtype='bool'))

data_array = DataArray([[1, -999, -999], [2, 3, -999], [-999, -999, float('nan')]], attrs=attrs)
dataset = Dataset(data_vars={'var_one': data_array})

output_ds = valid_data_mask(dataset)
assert output_ds.data_vars['var_one'].equals(expected_data_array)

output_da = valid_data_mask(data_array)
assert output_da.equals(expected_data_array)

expected_data_array = DataArray(np.array([[True, True, True], [True, True, True], [True, True, False]],
dtype='bool'))

data_array = DataArray([[1, -999, -999], [2, 3, -999], [-999, -999, float('nan')]])
dataset = Dataset(data_vars={'var_one': data_array})

output_ds = valid_data_mask(dataset)
assert output_ds.data_vars['var_one'].equals(expected_data_array)

output_da = valid_data_mask(data_array)
assert output_da.equals(expected_data_array)

0 comments on commit d1907c9

Please sign in to comment.