Skip to content

Commit

Permalink
enabled output of merge to have additional bands (#1933)
Browse files Browse the repository at this point in the history
* enabled output of merge to have additional bands

* return dataset index in merge callable

* added custom callable test func

* add support for specifying output dtype when merging

* removed funcsigs and used kwargs for copyto func
  • Loading branch information
normanb authored May 29, 2020
1 parent 98c6920 commit 5af075c
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 15 deletions.
50 changes: 35 additions & 15 deletions rasterio/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
MERGE_METHODS = ('first', 'last', 'min', 'max')


def merge(datasets, bounds=None, res=None, nodata=None, precision=10, indexes=None,
method='first'):
def merge(datasets, bounds=None, res=None, nodata=None, dtype=None, precision=10,
indexes=None, output_count=None, method='first'):
"""Copy valid pixels from input files to an output file.
All files must have the same number of bands, data type, and
Expand Down Expand Up @@ -45,10 +45,16 @@ def merge(datasets, bounds=None, res=None, nodata=None, precision=10, indexes=No
nodata: float, optional
nodata value to use in output file. If not set, uses the nodata value
in the first input raster.
dtype: numpy dtype or string
dtype to use in the output file. If not set, uses the dtype value in the
first input raster.
precision: float, optional
Number of decimal points of precision when computing inverse transform.
indexes : list of ints or a single int, optional
bands to read and merge
output_count: int, optional
If using a callable as the merge method, it may be useful to have
additional bands in the output in addition to the indexes specified for read.
method : str or callable
pre-defined method:
first: reverse painting
Expand All @@ -57,7 +63,7 @@ def merge(datasets, bounds=None, res=None, nodata=None, precision=10, indexes=No
max: pixel-wise max of existing and new
or custom callable with signature:
def function(old_data, new_data, old_nodata, new_nodata):
def function(old_data, new_data, old_nodata, new_nodata, index=None, roff=None, coff=None):
Parameters
----------
Expand All @@ -69,6 +75,12 @@ def function(old_data, new_data, old_nodata, new_nodata):
old_nodata, new_nodata : array_like
boolean masks where old/new data is nodata
same shape as old_data
index: int
index of the current dataset within the merged dataset collection
roff: int
row offset in base array
coff: int
column offset in base array
Returns
-------
Expand All @@ -86,19 +98,22 @@ def function(old_data, new_data, old_nodata, new_nodata):
first = datasets[0]
first_res = first.res
nodataval = first.nodatavals[0]
dtype = first.dtypes[0]
dt = first.dtypes[0]

if method not in MERGE_METHODS and not callable(method):
raise ValueError('Unknown method {0}, must be one of {1} or callable'
.format(method, MERGE_METHODS))

# Determine output band count
if indexes is None:
output_count = first.count
src_count = first.count
elif isinstance(indexes, int):
output_count = 1
src_count = indexes
else:
output_count = len(indexes)
src_count = len(indexes)

if not output_count:
output_count = src_count

# Extent from option or extent of all inputs
if bounds:
Expand Down Expand Up @@ -137,8 +152,12 @@ def function(old_data, new_data, old_nodata, new_nodata):
logger.debug("Output width: %d, height: %d", output_width, output_height)
logger.debug("Adjusted bounds: %r", (dst_w, dst_s, dst_e, dst_n))

if dtype is not None:
dt = dtype
logger.debug("Set dtype: %s", dt)

# create destination array
dest = np.zeros((output_count, output_height, output_width), dtype=dtype)
dest = np.zeros((output_count, output_height, output_width), dtype=dt)

if nodata is not None:
nodataval = nodata
Expand Down Expand Up @@ -168,25 +187,25 @@ def function(old_data, new_data, old_nodata, new_nodata):
nodataval = 0

if method == 'first':
def copyto(old_data, new_data, old_nodata, new_nodata):
def copyto(old_data, new_data, old_nodata, new_nodata, **kwargs):
mask = np.logical_and(old_nodata, ~new_nodata)
old_data[mask] = new_data[mask]

elif method == 'last':
def copyto(old_data, new_data, old_nodata, new_nodata):
def copyto(old_data, new_data, old_nodata, new_nodata, **kwargs):
mask = ~new_nodata
old_data[mask] = new_data[mask]

elif method == 'min':
def copyto(old_data, new_data, old_nodata, new_nodata):
def copyto(old_data, new_data, old_nodata, new_nodata, **kwargs):
mask = np.logical_and(~old_nodata, ~new_nodata)
old_data[mask] = np.minimum(old_data[mask], new_data[mask])

mask = np.logical_and(old_nodata, ~new_nodata)
old_data[mask] = new_data[mask]

elif method == 'max':
def copyto(old_data, new_data, old_nodata, new_nodata):
def copyto(old_data, new_data, old_nodata, new_nodata, **kwargs):
mask = np.logical_and(~old_nodata, ~new_nodata)
old_data[mask] = np.maximum(old_data[mask], new_data[mask])

Expand All @@ -199,7 +218,7 @@ def copyto(old_data, new_data, old_nodata, new_nodata):
else:
raise ValueError(method)

for src in datasets:
for idx, src in enumerate(datasets):
# Real World (tm) use of boundless reads.
# This approach uses the maximum amount of memory to solve the
# problem. Making it more efficient is a TODO.
Expand All @@ -226,7 +245,7 @@ def copyto(old_data, new_data, old_nodata, new_nodata):
# 4. Read data in source window into temp
trows, tcols = (
int(round(dst_window.height)), int(round(dst_window.width)))
temp_shape = (output_count, trows, tcols)
temp_shape = (src_count, trows, tcols)
temp = src.read(out_shape=temp_shape, window=src_window,
boundless=False, masked=True, indexes=indexes)

Expand All @@ -242,6 +261,7 @@ def copyto(old_data, new_data, old_nodata, new_nodata):
region_nodata = region == nodataval
temp_nodata = temp.mask

copyto(region, temp, region_nodata, temp_nodata)
copyto(region, temp, region_nodata, temp_nodata,
index=idx, roff=roff, coff=coff)

return dest, output_transform
35 changes: 35 additions & 0 deletions tests/test_rio_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,41 @@ def test_merge_overlapping(test_data_dir_overlapping, runner):
assert np.all(data == expected)


def test_merge_overlapping_callable_long(test_data_dir_overlapping, runner):
outputname = str(test_data_dir_overlapping.join('merged.tif'))
inputs = [str(x) for x in test_data_dir_overlapping.listdir()]
datasets = [rasterio.open(x) for x in inputs]
test_merge_overlapping_callable_long.index = 0

def mycallable(old_data, new_data, old_nodata, new_nodata,
index=None, roff=None, coff=None):
assert old_data.shape[0] == 5
assert new_data.shape[0] == 1
assert test_merge_overlapping_callable_long.index == index
test_merge_overlapping_callable_long.index += 1

merge(datasets, output_count=5, method=mycallable)


def test_custom_callable_merge(test_data_dir_overlapping, runner):
inputs = ['tests/data/world.byte.tif'] * 3
datasets = [rasterio.open(x) for x in inputs]
meta = datasets[0].meta
output_count = 4

def mycallable(old_data, new_data, old_nodata, new_nodata,
index=None, roff=None, coff=None):
# input data are bytes, test output doesn't overflow
old_data[index] = (index + 1) * 259 # use a number > 255 but divisible by 3 for testing
# update additional band that we specified in output_count
old_data[3, :, :] += index

arr, _ = merge(datasets, output_count=output_count, method=mycallable, dtype=np.uint64)

np.testing.assert_array_equal(np.mean(arr[:3], axis=0), 518)
np.testing.assert_array_equal(arr[3, :, :], 3)


# Fixture to create test datasets within temporary directory
@fixture(scope='function')
def test_data_dir_float(tmpdir):
Expand Down

0 comments on commit 5af075c

Please sign in to comment.