Skip to content

Commit

Permalink
ENH: Add support for merging datasets with different CRS (#183)
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Nov 10, 2020
1 parent 388bf90 commit 82426bd
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 18 deletions.
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History
Latest
------
- ENH: Added `rio.estimate_utm_crs` (issue #181)
- ENH: Add support for merging datasets with different CRS (issue #173)

0.1.1
------
Expand Down
53 changes: 40 additions & 13 deletions rioxarray/merge.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Callable, Iterable, Tuple, Union
from typing import Callable, Iterable, Optional, Tuple, Union

import numpy
from rasterio.crs import CRS
from rasterio.merge import merge as _rio_merge
from xarray import DataArray, Dataset

Expand Down Expand Up @@ -60,11 +61,12 @@ def read(self, window, out_shape, *args, **kwargs) -> numpy.ma.array:

def merge_arrays(
dataarrays: Iterable[DataArray],
bounds: Union[Tuple, None] = None,
res: Union[Tuple, None] = None,
nodata: Union[float, None] = None,
precision: Union[float, None] = None,
bounds: Optional[Tuple] = None,
res: Optional[Tuple] = None,
nodata: Optional[float] = None,
precision: Optional[float] = None,
method: Union[str, Callable, None] = None,
crs: Optional[CRS] = None,
parse_coordinates: bool = True,
) -> DataArray:
"""
Expand All @@ -73,6 +75,8 @@ def merge_arrays(
Uses rasterio.merge.merge:
https://rasterio.readthedocs.io/en/stable/api/rasterio.merge.html#rasterio.merge.merge
.. versionadded:: 0.2 crs
Parameters
----------
dataarrays: list
Expand All @@ -93,6 +97,8 @@ def merge_arrays(
Number of decimal points of precision when computing inverse transform.
method: str or callable, optional
See rasterio docs.
crs: rasterio.crs.CRS, optional
Output CRS. If not set, the CRS of the first DataArray is used.
parse_coordinates: bool, optional
If False, it will disable loading spatial coordinates.
Expand All @@ -105,12 +111,27 @@ def merge_arrays(
input_kwargs = dict(
bounds=bounds, res=res, nodata=nodata, precision=precision, method=method
)
if crs is None:
crs = dataarrays[0].rio.crs
if res is None:
res = tuple(abs(res_val) for res_val in dataarrays[0].rio.resolution())
rioduckarrays = []
for dataarray in dataarrays:
da_res = tuple(abs(res_val) for res_val in dataarray.rio.resolution())
if da_res != res or dataarray.rio.crs != crs:
rioduckarrays.append(
RasterioDatasetDuck(
dataarray.rio.reproject(dst_crs=crs, resolution=res)
)
)
else:
rioduckarrays.append(RasterioDatasetDuck(dataarray))
merged_data, merged_transform = _rio_merge(
[RasterioDatasetDuck(dataarray) for dataarray in dataarrays],
rioduckarrays,
**{key: val for key, val in input_kwargs.items() if val is not None},
)
merged_shape = merged_data.shape
representative_array = dataarrays[0]
representative_array = rioduckarrays[0]._xds
if parse_coordinates:
coords = _make_coords(
representative_array, merged_transform, merged_shape[-1], merged_shape[-2]
Expand All @@ -120,7 +141,7 @@ def merge_arrays(

out_attrs = representative_array.attrs
xda = DataArray(
name=dataarrays[0].name,
name=representative_array.name,
data=merged_data,
coords=coords,
dims=tuple(representative_array.dims),
Expand All @@ -135,18 +156,21 @@ def merge_arrays(

def merge_datasets(
datasets: Iterable[Dataset],
bounds: Union[Tuple, None] = None,
res: Union[Tuple, None] = None,
nodata: Union[float, None] = None,
precision: Union[float, None] = None,
bounds: Optional[Tuple] = None,
res: Optional[Tuple] = None,
nodata: Optional[float] = None,
precision: Optional[float] = None,
method: Union[str, Callable, None] = None,
crs: Optional[CRS] = None,
) -> DataArray:
"""
Merge datasets geospatially.
Uses rasterio.merge.merge:
https://rasterio.readthedocs.io/en/stable/api/rasterio.merge.html#rasterio.merge.merge
.. versionadded:: 0.2 crs
Parameters
----------
datasets: list
Expand All @@ -167,6 +191,8 @@ def merge_datasets(
Number of decimal points of precision when computing inverse transform.
method: str or callable, optional
See rasterio docs.
crs: rasterio.crs.CRS, optional
Output CRS. If not set, the CRS of the first DataArray is used.
Returns
-------
Expand All @@ -184,6 +210,7 @@ def merge_datasets(
nodata=nodata,
precision=precision,
method=method,
crs=crs,
parse_coordinates=False,
)
data_var = list(representative_ds.data_vars)[0]
Expand All @@ -197,5 +224,5 @@ def merge_datasets(
),
attrs=representative_ds.attrs,
)
xds.rio.write_crs(representative_ds.rio.crs, inplace=True)
xds.rio.write_crs(merged_data[data_var].rio.crs, inplace=True)
return xds
5 changes: 5 additions & 0 deletions rioxarray/rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import copy
import math
import warnings
from typing import Iterable
from uuid import uuid4

import numpy as np
Expand Down Expand Up @@ -176,6 +177,10 @@ def _make_dst_affine(
src_bounds = src_data_array.rio.bounds()
src_height, src_width = src_data_array.rio.shape
dst_height, dst_width = dst_shape if dst_shape is not None else (None, None)
if isinstance(dst_resolution, Iterable):
dst_resolution = tuple(abs(res_val) for res_val in dst_resolution)
elif dst_resolution is not None:
dst_resolution = abs(dst_resolution)
resolution_or_width_height = {
k: v
for k, v in [
Expand Down
64 changes: 59 additions & 5 deletions test/integration/test_integration_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,58 @@ def test_merge_arrays():
assert_almost_equal(merged.sum(), 11368261)


@pytest.mark.parametrize("dataset", [True, False])
def test_merge__different_crs(dataset):
dem_test = os.path.join(TEST_INPUT_DATA_DIR, "MODIS_ARRAY.nc")
with open_rasterio(dem_test) as rds:
crs = rds.rio.crs
if dataset:
rds = rds.to_dataset()
arrays = [
rds.isel(x=slice(100), y=slice(100)).rio.reproject("EPSG:3857"),
rds.isel(x=slice(100, 200), y=slice(100, 200)),
rds.isel(x=slice(100), y=slice(100, 200)),
rds.isel(x=slice(100, 200), y=slice(100)),
]
if dataset:
merged = merge_datasets(arrays, crs=crs)
else:
merged = merge_arrays(arrays, crs=crs)
assert_almost_equal(
merged.rio.bounds(),
(-7300984.0238134, 5003618.5908794, -7223500.6583578, 5050108.6101528),
)
assert_almost_equal(
tuple(merged.rio.transform()),
(
553.4526103969893,
0.0,
-7300984.023813409,
0.0,
-553.4526103969796,
5050108.610152751,
0.0,
0.0,
1.0,
),
)
assert merged.rio.shape == (84, 140)
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert merged.rio.crs == rds.rio.crs
if dataset:
assert merged.attrs == {"grid_mapping": "spatial_ref"}
assert_almost_equal(merged[merged.rio.vars[0]].sum(), -131013894)
else:
assert merged.attrs == {
"_FillValue": -28672,
"add_offset": 0.0,
"grid_mapping": "spatial_ref",
"scale_factor": 1.0,
}
assert_almost_equal(merged.sum(), -131013894)


def test_merge_arrays__res():
dem_test = os.path.join(TEST_INPUT_DATA_DIR, "MODIS_ARRAY.nc")
with open_rasterio(dem_test, masked=True) as rds:
Expand All @@ -70,20 +122,22 @@ def test_merge_arrays__res():

assert_almost_equal(
merged.rio.bounds(),
(-7274009.649486291, 5003608.61015275, -7227509.649486291, 5050108.61015275),
(-7274009.6494863, 5003308.6101528, -7227209.6494863, 5050108.6101528),
)
assert_almost_equal(
tuple(merged.rio.transform()),
(300.0, 0.0, -7274009.649486291, 0.0, -300.0, 5050108.61015275, 0.0, 0.0, 1.0),
)
assert merged.rio._cached_transform() == merged.rio.transform()
assert merged.rio.shape == (155, 155)
assert merged.rio.shape == (156, 156)
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert merged.rio.crs == rds.rio.crs
assert_almost_equal(merged.attrs.pop("_FillValue"), rds.attrs.pop("_FillValue"))
assert merged.attrs == rds.attrs
assert_almost_equal(nansum(merged), 13556564)
compare_attrs = dict(rds.attrs)
compare_attrs.pop("crs")
assert merged.attrs == compare_attrs
assert_almost_equal(nansum(merged), 13760565)


@pytest.mark.xfail(os.name == "nt", reason="On windows the merged data is different.")
Expand Down Expand Up @@ -185,4 +239,4 @@ def test_merge_datasets__res():
base_attrs = dict(rds.attrs)
base_attrs["grid_mapping"] = "spatial_ref"
assert merged.attrs == base_attrs
assert_almost_equal(merged[data_var].sum(), 973667940761024)
assert_almost_equal(merged[data_var].sum(), 974566547463955)

0 comments on commit 82426bd

Please sign in to comment.