Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow destination shape to be specified in reproject() #116

Merged
merged 9 commits into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
History
=======

Unreleased
----------

- ENH: Added optional `shape` argument to `rio.reproject` (pull #116)

0.0.26
------
- ENH: Added :func:`rioxarray.show_versions` (issue #106)
Expand Down
37 changes: 34 additions & 3 deletions rioxarray/rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,29 @@ def _make_coords(src_data_array, dst_affine, dst_width, dst_height, dst_crs):
return add_xy_grid_meta(new_coords)


def _make_dst_affine(src_data_array, src_crs, dst_crs, dst_resolution=None):
def _make_dst_affine(
src_data_array, src_crs, dst_crs, dst_resolution=None, dst_shape=None
):
"""Determine the affine of the new projected `xarray.DataArray`"""
src_bounds = src_data_array.rio.bounds()
src_width, src_height = src_data_array.rio.shape
dst_height, dst_width = dst_shape if dst_shape is not None else (None, None)
resolution_or_width_height = {
snowman2 marked this conversation as resolved.
Show resolved Hide resolved
k: v
for k, v in [
("resolution", dst_resolution),
("dst_height", dst_height),
("dst_width", dst_width),
]
if v is not None
}
dst_affine, dst_width, dst_height = rasterio.warp.calculate_default_transform(
src_crs, dst_crs, src_width, src_height, *src_bounds, resolution=dst_resolution
src_crs,
dst_crs,
src_width,
src_height,
*src_bounds,
**resolution_or_width_height,
)
return dst_affine, dst_width, dst_height

Expand Down Expand Up @@ -851,6 +868,7 @@ def reproject(
self,
dst_crs,
resolution=None,
shape=None,
dst_affine_width_height=None,
resampling=Resampling.nearest,
):
Expand All @@ -864,13 +882,18 @@ def reproject(
a 'crs' attribute to be set containing a valid CRS.
If using a WKT (e.g. from spatiareference.org), make sure it is an OGC WKT.

.. versionadded:: 0.0.27 shape

Parameters
----------
dst_crs: str
OGC WKT string or Proj.4 string.
resolution: float or tuple(float, float), optional
Size of a destination pixel in destination projection units
(e.g. degrees or metres).
shape: tuple(int, int), optional
Shape of the destination in pixels (dst_height, dst_width). Cannot be used
together with resolution.
dst_affine_width_height: tuple(dst_affine, dst_width, dst_height), optional
Tuple with the destination affine, width, and height.
resampling: Resampling method, optional
Expand All @@ -882,6 +905,8 @@ def reproject(
:class:`xarray.DataArray`: A reprojected DataArray.

"""
if resolution is not None and shape is not None:
raise RioXarrayError("resolution and shape cannot be used together.")
if self.crs is None:
raise MissingCRS(
"CRS not found. Please set the CRS with 'set_crs()' or 'write_crs()'."
Expand All @@ -892,7 +917,7 @@ def reproject(
dst_affine, dst_width, dst_height = dst_affine_width_height
else:
dst_affine, dst_width, dst_height = _make_dst_affine(
self._obj, self.crs, dst_crs, resolution
self._obj, self.crs, dst_crs, resolution, shape
)
extra_dim = self._check_dimensions()
if extra_dim:
Expand Down Expand Up @@ -1385,6 +1410,7 @@ def reproject(
self,
dst_crs,
resolution=None,
shape=None,
dst_affine_width_height=None,
resampling=Resampling.nearest,
):
Expand All @@ -1396,6 +1422,7 @@ def reproject(
a 'crs' attribute to be set containing a valid CRS.
If using a WKT (e.g. from spatiareference.org), make sure it is an OGC WKT.

.. versionadded:: 0.0.27 shape

Parameters
----------
Expand All @@ -1404,6 +1431,9 @@ def reproject(
resolution: float or tuple(float, float), optional
Size of a destination pixel in destination projection units
(e.g. degrees or metres).
shape: tuple(int, int), optional
Shape of the destination in pixels (dst_height, dst_width). Cannot be used
together with resolution.
dst_affine_width_height: tuple(dst_affine, dst_width, dst_height), optional
Tuple with the destination affine, width, and height.
resampling: Resampling method, optional
Expand All @@ -1423,6 +1453,7 @@ def reproject(
.rio.reproject(
dst_crs,
resolution=resolution,
shape=shape,
dst_affine_width_height=dst_affine_width_height,
resampling=resampling,
)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_version():
"nbsphinx",
"sphinx_rtd_theme",
"black",
"flake8",
"flake8==3.7",
"pylint",
"isort",
"pre-commit",
Expand Down
28 changes: 28 additions & 0 deletions test/integration/test_integration_rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,23 @@ def test_transform_bounds():
)


def test_reproject_with_shape(modis_reproject):
new_shape = (9, 10)
mask_args = (
dict(masked=False)
if "open_rasterio" in str(modis_reproject["open"])
else dict(mask_and_scale=False)
)
with modis_reproject["open"](modis_reproject["input"], **mask_args) as mda:
mds_repr = mda.rio.reproject(modis_reproject["to_proj"], shape=new_shape)
# test
if hasattr(mds_repr, "variables"):
for var in mds_repr.rio.vars:
assert mds_repr[var].shape == new_shape
else:
assert mds_repr.shape == new_shape


def test_reproject(modis_reproject):
mask_args = (
dict(masked=False)
Expand Down Expand Up @@ -1256,6 +1273,17 @@ def test_reproject_missing_crs():
test_da.rio.reproject(4326)


def test_reproject_resolution_and_shape():
test_da = xarray.DataArray(
numpy.zeros((5, 5)),
dims=("y", "x"),
coords={"y": numpy.arange(1, 6), "x": numpy.arange(2, 7)},
attrs={"crs": "+init=epsg:3857"},
)
with pytest.raises(RioXarrayError):
test_da.rio.reproject(4326, resolution=1, shape=(1, 1))


class CustomCRS(object):
@property
def wkt(self):
Expand Down