Skip to content

Commit

Permalink
BUG: Use internal reprojection as engine for resampling window in merge
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Jun 12, 2020
1 parent 74bb04e commit 19e56b5
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[settings]
line_length=88
multi_line_output=3
known_third_party=PIL,affine,dask,mock,numpy,pyproj,pytest,rasterio,scipy,setuptools,xarray
known_third_party=affine,dask,mock,numpy,pyproj,pytest,rasterio,scipy,setuptools,xarray
known_first_party=rioxarray,test
include_trailing_comma=true
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ History
- Fix ``RasterioDeprecationWarning`` (pull #117)
- BUG: Make rio.shape order same as rasterio dataset shape (height, width) (pull #121)
- Fix open_rasterio() for WarpedVRT with specified src_crs (pydata/xarray/pull/4104 & pull 120)
- BUG: Use internal reprojection as engine for resampling window in merge (pull #123)

0.0.26
------
Expand Down
5 changes: 1 addition & 4 deletions rioxarray/_show_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def _get_deps_info():
deps_info: dict
version information on relevant Python libraries
"""
# Note: PIL is the module name, howerver pillow is the dependency
deps = ["PIL", "scipy", "pyproj"]
deps = ["scipy", "pyproj"]

def get_version(module):
try:
Expand All @@ -70,8 +69,6 @@ def get_version(module):
else:
mod = importlib.import_module(modname)
ver = get_version(mod)
# use PIL only to get version information
modname = "pillow" if modname == "PIL" else modname
deps_info[modname] = ver
except ImportError:
deps_info[modname] = None
Expand Down
27 changes: 11 additions & 16 deletions rioxarray/merge.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Callable, Iterable, Tuple, Union

import numpy
from PIL import Image
from rasterio.merge import merge as _rio_merge
from xarray import DataArray, Dataset

Expand Down Expand Up @@ -30,35 +29,31 @@ def read(self, window, out_shape, *args, **kwargs) -> numpy.ma.array:
"""
This method is meant to be used by the rasterio.merge.merge function.
"""
data_window = self._xds.rio.isel_window(window).values
data_window = self._xds.rio.isel_window(window)
if data_window.shape != out_shape:
# in this section, the data is geographically the same
# however it is not the same dimensions as requested
# so need to resample to the reqwuested shape
if len(data_window.shape) == 3:
new_data = numpy.empty(out_shape, dtype=data_window.dtype)
count, height, width = out_shape
for iii in range(data_window.shape[0]):
new_data[iii] = numpy.array(
Image.fromarray(data_window[iii]).resize((width, height))
)
data_window = new_data
# so need to resample to the requested shape
if len(out_shape) == 3:
_, out_height, out_width = out_shape
else:
data_window = numpy.array(
Image.fromarray(data_window).resize(out_shape)
)
out_height, out_width = out_shape
data_window = self._xds.rio.reproject(
self._xds.rio.crs,
dst_affine_width_height=(self.transform, out_width, out_height),
)

nodata = self.nodatavals[0]
mask = False
fill_value = None
if numpy.isnan(nodata):
if nodata is not None and numpy.isnan(nodata):
mask = numpy.isnan(data_window)
elif nodata is not None:
mask = data_window == nodata
fill_value = nodata

return numpy.ma.array(
data_window, mask=mask, fill_value=fill_value, dtype=self._xds.dtype
data_window, mask=mask, fill_value=fill_value, dtype=self.dtypes[0]
)


Expand Down
14 changes: 3 additions & 11 deletions test/integration/test_integration_merge.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import os
from distutils.version import LooseVersion

import pytest
from numpy import nansum
from numpy.testing import assert_almost_equal
from rasterio import gdal_version

from rioxarray import open_rasterio
from rioxarray.merge import merge_arrays, merge_datasets
Expand Down Expand Up @@ -88,10 +86,7 @@ def test_merge_arrays__res():
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert merged.rio.crs == rds.rio.crs
assert merged.attrs == rds.attrs
if LooseVersion(gdal_version()) >= LooseVersion("2.4.4"):
assert_almost_equal(nansum(merged), 13754521.430030823)
else:
assert_almost_equal(nansum(merged), 13767944)
assert_almost_equal(nansum(merged), 13556564)


@pytest.mark.xfail(os.name == "nt", reason="On windows the merged data is different.")
Expand Down Expand Up @@ -147,7 +142,7 @@ def test_merge_datasets():
base_attrs = dict(rds.attrs)
base_attrs["grid_mapping"] = "spatial_ref"
assert merged.attrs == base_attrs
assert_almost_equal(merged[data_var].sum(), 4539265823591471)
assert_almost_equal(merged[data_var].sum(), 4543446965182987)


@pytest.mark.xfail(os.name == "nt", reason="On windows the merged data is different.")
Expand Down Expand Up @@ -193,7 +188,4 @@ def test_merge_datasets__res():
base_attrs = dict(rds.attrs)
base_attrs["grid_mapping"] = "spatial_ref"
assert merged.attrs == base_attrs
if LooseVersion(gdal_version()) >= LooseVersion("2.4.4"):
assert_almost_equal(merged[data_var].sum(), 974565505482489)
else:
assert_almost_equal(merged[data_var].sum(), 974565970607345)
assert_almost_equal(merged[data_var].sum(), 973667940761024)
1 change: 0 additions & 1 deletion test/unit/test_show_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def test_get_sys_info():
def test_get_deps_info():
deps_info = _get_deps_info()

assert "pillow" in deps_info
assert "scipy" in deps_info
assert "pyproj" in deps_info

Expand Down

0 comments on commit 19e56b5

Please sign in to comment.