Skip to content

Commit

Permalink
Avoid masked arrays in merge.py
Browse files Browse the repository at this point in the history
Also rely on GDAL to produce the masks for masked arrays instead of
using our own masking logic.
  • Loading branch information
Sean Gillies committed Mar 4, 2015
1 parent 70eebc4 commit 4fc920b
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 24 deletions.
1 change: 1 addition & 0 deletions rasterio/_gdal.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ cdef extern from "gdal.h" nogil:
int GDALGetRasterColorInterpretation (void *hBand)
int GDALSetRasterColorInterpretation (void *hBand, int)

int GDALGetMaskFlags (void *hBand)
void *GDALGetMaskBand (void *hBand)
int GDALCreateMaskBand (void *hDS, int flags)

Expand Down
91 changes: 70 additions & 21 deletions rasterio/_io.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,8 @@ cdef class RasterReader(_base.DatasetReader):
preferentially used by callers.
"""

cdef void *hband = NULL

return2d = False
if indexes is None:
indexes = self.indexes
Expand Down Expand Up @@ -689,8 +691,30 @@ cdef class RasterReader(_base.DatasetReader):
(out.shape, win_shape))
if masked is None:
masked = hasattr(out, 'mask')

# Masking
# -------
#
# If masked is True, we read the GDAL mask bands
# using read_masks(), invert them and use them in constructing
# masked arrays.
#
# If masked is None, we check the GDAL mask flags using
# GDALGetMaskFlags. If GMF_ALL_VALID, we do not create a
# masked array. Else, we call read_masks() and process as
# above.

if masked is None:
masked = any([x is not None for x in nodatavals])
mask_flags = [0]*self.count
for i, j in zip(range(self.count), self.indexes):
hband = _gdal.GDALGetRasterBand(self._hds, j)
mask_flags[i] = _gdal.GDALGetMaskFlags(hband)

masked = any([flag & 0x01 == 0 for flag in mask_flags])

log.debug("masked: %s", masked)
log.debug("mask_flags: %r", mask_flags)

if out is None:
out = np.zeros(win_shape, dtype)
for ndv, arr in zip(
Expand All @@ -703,6 +727,15 @@ cdef class RasterReader(_base.DatasetReader):
if not boundless or not window:
out = self._read(indexes, out, window, dtype)

if masked:
mask = np.empty(out.shape, 'uint8')
mask = ~self._read(
indexes, mask, window, 'uint8', masks=True).astype('bool')
kwargs = {'mask': mask}
if nodatavals:
kwargs['fill_value'] = nodatavals[0]
out = np.ma.array(out, **kwargs)

else:
# Compute the overlap between the dataset and the boundless window.
overlap = ((
Expand All @@ -718,9 +751,21 @@ cdef class RasterReader(_base.DatasetReader):
overlap_w = overlap[1][1] - overlap[1][0]
scaling_h = float(out.shape[-2:][0])/window_h
scaling_w = float(out.shape[-2:][1])/window_w
buffer_shape = (int(overlap_h*scaling_h), int(overlap_w*scaling_w))
buffer_shape = (
int(overlap_h*scaling_h), int(overlap_w*scaling_w))
data = np.empty(win_shape[:-2] + buffer_shape, dtype)
data = self._read(indexes, data, overlap, dtype)

if masked:
mask = np.zeros(win_shape[:-2] + buffer_shape, 'uint8')
mask = ~self._read(
indexes, mask, overlap, 'uint8', masks=True
).astype('bool')
kwargs = {'mask': mask}
if nodatavals:
kwargs['fill_value'] = nodatavals[0]
data = np.ma.array(data, **kwargs)

else:
data = None

Expand All @@ -737,26 +782,30 @@ cdef class RasterReader(_base.DatasetReader):
out if len(out.shape) == 3 else [out],
data if len(data.shape) == 3 else [data]):
dst[roff:roff+data_h, coff:coff+data_w] = src
if hasattr(dst, 'mask'):
dst.mask[roff:roff+data_h, coff:coff+data_w] = src


# if masked:
#
# if len(set(nodatavals)) == 1:
# if nodatavals[0] is None:
# out = np.ma.masked_array(out, copy=False)
# elif np.isnan(nodatavals[0]):
# out = np.ma.masked_where(np.isnan(out), out, copy=False)
# else:
# out = np.ma.masked_equal(out, nodatavals[0], copy=False)
# else:
# out = np.ma.masked_array(out, copy=False)
# for aix in range(len(indexes)):
# if nodatavals[aix] is None:
# band_mask = False
# elif np.isnan(nodatavals[aix]):
# band_mask = np.isnan(out[aix])
# else:
# band_mask = out[aix] == nodatavals[aix]
# out[aix].mask = band_mask

# Masking the output. TODO: explain the logic better.
if masked:
if len(set(nodatavals)) == 1:
if nodatavals[0] is None:
out = np.ma.masked_array(out, copy=False)
elif np.isnan(nodatavals[0]):
out = np.ma.masked_where(np.isnan(out), out, copy=False)
else:
out = np.ma.masked_equal(out, nodatavals[0], copy=False)
else:
out = np.ma.masked_array(out, copy=False)
for aix in range(len(indexes)):
if nodatavals[aix] is None:
band_mask = False
elif np.isnan(nodatavals[aix]):
band_mask = np.isnan(out[aix])
else:
band_mask = out[aix] == nodatavals[aix]
out[aix].mask = band_mask
if return2d:
out.shape = out.shape[1:]

Expand Down
11 changes: 8 additions & 3 deletions rasterio/rio/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,21 @@ def merge(ctx, files, driver, bounds, res, nodata):
out=data,
window=window,
boundless=True,
masked=True)
masked=False)
mask = np.zeros_like(dest, 'uint8')
mask = src.read_masks(
out=mask,
window=window,
boundless=True)
np.copyto(dest, data,
where=np.logical_and(
dest==nodataval, data.mask==False))
dest==nodataval, mask>0))

if dst.mode == 'r+':
data = dst.read(masked=True)
np.copyto(dest, data,
where=np.logical_and(
dest==nodataval, data.mask==False))
dest==nodataval, mask>0))

dst.write(dest)
dst.close()
Expand Down

0 comments on commit 4fc920b

Please sign in to comment.