Skip to content

Commit

Permalink
Merge fe52b0b into 669d86b
Browse files Browse the repository at this point in the history
  • Loading branch information
sgillies committed Dec 3, 2020
2 parents 669d86b + fe52b0b commit ee6fb20
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 36 deletions.
7 changes: 5 additions & 2 deletions CHANGES.txt
@@ -1,8 +1,11 @@
Changes
=======

1.2dev
------
1.2a1
-----

- Add dst_path and dst_kwds parameters to rasterio's merge tool to allow
results to be written directly to a dataset (#1867).

1.1.8 (2020-10-20)
------------------
Expand Down
2 changes: 1 addition & 1 deletion rasterio/__init__.py
Expand Up @@ -41,7 +41,7 @@ def emit(self, record):
import rasterio.path

__all__ = ['band', 'open', 'pad', 'Env']
__version__ = "1.2dev"
__version__ = "1.2a1"
__gdal_version__ = gdal_version()

# Rasterio attaches NullHandler to the 'rasterio' logger and its
Expand Down
52 changes: 45 additions & 7 deletions rasterio/merge.py
@@ -1,18 +1,17 @@
"""Copy valid pixels from input files to an output file."""

from contextlib import contextmanager
from pathlib import Path
import logging
import math
from pathlib import Path
import warnings

import numpy as np

import rasterio
from rasterio import windows
from rasterio.enums import Resampling
from rasterio.compat import string_types
from rasterio.enums import Resampling
from rasterio import windows
from rasterio.transform import Affine


Expand All @@ -21,9 +20,20 @@
MERGE_METHODS = ('first', 'last', 'min', 'max')


def merge(datasets, bounds=None, res=None, nodata=None, dtype=None, precision=10,
indexes=None, output_count=None, resampling=Resampling.nearest,
method='first'):
def merge(
datasets,
bounds=None,
res=None,
nodata=None,
dtype=None,
precision=10,
indexes=None,
output_count=None,
resampling=Resampling.nearest,
method="first",
dst_path=None,
dst_kwds=None,
):
"""Copy valid pixels from input files to an output file.
All files must have the same number of bands, data type, and
Expand Down Expand Up @@ -90,6 +100,11 @@ def function(old_data, new_data, old_nodata, new_nodata, index=None, roff=None,
row offset in base array
coff: int
column offset in base array
dst_path : str or Pathlike, optional
Path of output dataset
dst_kwds : dict, optional
Dictionary of creation options and other paramters that will be
overlaid on the profile of the output dataset.
Returns
-------
Expand Down Expand Up @@ -124,6 +139,7 @@ def nullcontext(obj):
dataset_opener = nullcontext

with dataset_opener(datasets[0]) as first:
first_profile = first.profile
first_res = first.res
nodataval = first.nodatavals[0]
dt = first.dtypes[0]
Expand All @@ -135,6 +151,11 @@ def nullcontext(obj):
else:
src_count = len(indexes)

try:
first_colormap = first.colormap(1)
except ValueError:
first_colormap = None

if not output_count:
output_count = src_count

Expand Down Expand Up @@ -180,6 +201,16 @@ def nullcontext(obj):
dt = dtype
logger.debug("Set dtype: %s", dt)

out_profile = first_profile
out_profile.update(**(dst_kwds or {}))

out_profile["transform"] = output_transform
out_profile["height"] = output_height
out_profile["width"] = output_width
out_profile["count"] = output_count
if nodata is not None:
out_profile["nodata"] = nodata

# create destination array
dest = np.zeros((output_count, output_height, output_width), dtype=dt)

Expand Down Expand Up @@ -296,4 +327,11 @@ def copyto(old_data, new_data, old_nodata, new_nodata, **kwargs):
copyto(region, temp, region_nodata, temp_nodata,
index=idx, roff=roff, coff=coff)

return dest, output_transform
if dst_path is None:
return dest, output_transform

else:
with rasterio.open(dst_path, "w", **out_profile) as dst:
dst.write(dest)
if first_colormap:
dst.write_colormap(1, first_colormap)
29 changes: 3 additions & 26 deletions rasterio/rio/merge.py
Expand Up @@ -3,7 +3,6 @@

import click

import rasterio
from rasterio.enums import Resampling
from rasterio.rio import options
from rasterio.rio.helpers import resolve_inout
Expand Down Expand Up @@ -56,36 +55,14 @@ def merge(ctx, files, output, driver, bounds, res, resampling,
resampling = Resampling[resampling]

with ctx.obj["env"]:
dest, output_transform = merge_tool(
merge_tool(
files,
bounds=bounds,
res=res,
nodata=nodata,
precision=precision,
indexes=(bidx or None),
resampling=resampling,
dst_path=output,
dst_kwds=creation_options,
)

with rasterio.open(files[0]) as first:
profile = first.profile
profile["transform"] = output_transform
profile["height"] = dest.shape[1]
profile["width"] = dest.shape[2]
profile["count"] = dest.shape[0]
profile.pop("driver", None)
if driver:
profile["driver"] = driver
if nodata is not None:
profile["nodata"] = nodata

profile.update(**creation_options)

with rasterio.open(output, "w", **profile) as dst:
dst.write(dest)

# uses the colormap in the first input raster.
try:
colormap = first.colormap(1)
dst.write_colormap(1, colormap)
except ValueError:
pass
13 changes: 13 additions & 0 deletions tests/test_rio_merge.py
Expand Up @@ -606,6 +606,19 @@ def test_merge_pathlib_path(tiffs):
merge(inputs, res=2)


def test_merge_output_dataset(tiffs, tmpdir):
"""Write to an open dataset"""
inputs = [str(x) for x in tiffs.listdir()]
inputs.sort()
output_file = tmpdir.join("output.tif")
merge(inputs, res=2, dst_path=str(output_file), dst_kwds=dict(driver="PNG"))

with rasterio.open(str(output_file)) as result:
assert result.count == 1
assert result.driver == "PNG"
assert result.height == result.width == 2


@fixture(scope='function')
def test_data_dir_resampling(tmpdir):
kwargs = {
Expand Down

0 comments on commit ee6fb20

Please sign in to comment.