Skip to content

Commit

Permalink
updated to support 3D datasets on export
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Jun 25, 2019
1 parent 60ad031 commit fc74a2f
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 15 deletions.
8 changes: 8 additions & 0 deletions rioxarray/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,11 @@ class OneDimensionalRaster(RioXarrayError):

class SingleVariableDataset(RioXarrayError):
"""This is for when you have a dataset with a single variable."""


class TooManyDimensions(RioXarrayError):
"""This is raised when there are more dimensions than is supported by the method"""


class InvalidDimensionOrder(RioXarrayError):
"""This is raised when there the dimensions are not ordered correctly."""
58 changes: 43 additions & 15 deletions rioxarray/rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
from rasterio.features import geometry_mask
from scipy.interpolate import griddata

from rioxarray.exceptions import NoDataInBounds, OneDimensionalRaster
from rioxarray.exceptions import (
InvalidDimensionOrder,
NoDataInBounds,
OneDimensionalRaster,
TooManyDimensions,
)
from rioxarray.crs import crs_to_wkt

FILL_VALUE_NAMES = ("_FillValue", "missing_value", "fill_value")
Expand Down Expand Up @@ -327,6 +332,32 @@ def _int_bounds(self):
bottom = float(self._obj[self.y_dim][-1])
return left, bottom, right, top

def _check_dimensions(self):
"""
This function validates that the dimensions 2D/3D and
they are are in the proper order.
Returns:
--------
str or None: Name extra dimension.
"""
extra_dims = list(set(list(self._obj.dims)) - set([self.x_dim, self.y_dim]))
if len(extra_dims) > 1:
raise TooManyDimensions("Only 2D and 3D data arrays supported.")
elif extra_dims and self._obj.dims != (extra_dims[0], self.y_dim, self.x_dim):
raise InvalidDimensionOrder(
"Invalid dimension order. Expected order: {}".format(
(extra_dims[0], self.y_dim, self.x_dim)
)
)
elif not extra_dims and self._obj.dims != (self.y_dim, self.x_dim):
raise InvalidDimensionOrder(
"Invalid dimension order. Expected order: {}".format(
(self.y_dim, self.x_dim)
)
)
return extra_dims[0] if extra_dims else None

def bounds(self, recalc=False):
"""Determine the bounds of the `xarray.DataArray`
Expand Down Expand Up @@ -438,17 +469,13 @@ def reproject(
dst_affine, dst_width, dst_height = _make_dst_affine(
self._obj, self.crs, dst_crs, resolution
)
extra_dims = list(set(list(self._obj.dims)) - set([self.x_dim, self.y_dim]))
if len(extra_dims) > 1:
raise RuntimeError("Reproject only supports 2D and 3D datasets.")
if extra_dims:
assert self._obj.dims == (extra_dims[0], self.y_dim, self.x_dim)
extra_dim = self._check_dimensions()
if extra_dim:
dst_data = np.zeros(
(self._obj[extra_dims[0]].size, dst_height, dst_width),
(self._obj[extra_dim].size, dst_height, dst_width),
dtype=self._obj.dtype.type,
)
else:
assert self._obj.dims == (self.y_dim, self.x_dim)
dst_data = np.zeros((dst_height, dst_width), dtype=self._obj.dtype.type)

try:
Expand Down Expand Up @@ -742,13 +769,10 @@ def interpolate_na(self, method="nearest"):
:class:`xarray.DataArray`: An interpolated :class:`xarray.DataArray` object.
"""
extra_dims = list(set(list(self._obj.dims)) - set([self.x_dim, self.y_dim]))
if len(extra_dims) > 1:
raise RuntimeError("Interpolate only supports 2D and 3D datasets.")
if extra_dims:
assert self._obj.dims == (extra_dims[0], self.y_dim, self.x_dim)
extra_dim = self._check_dimensions()
if extra_dim:
interp_data = []
for _, sub_xds in self._obj.groupby(extra_dims[0]):
for _, sub_xds in self._obj.groupby(extra_dim):
interp_data.append(
self._interpolate_na(sub_xds.load().data, method=method)
)
Expand Down Expand Up @@ -786,13 +810,17 @@ def to_raster(self, raster_path, driver="GTiff", dtype=None):
"""
width, height = self.shape
dtype = str(self._obj.dtype) if dtype is None else dtype
extra_dim = self._check_dimensions()
count = 1
if extra_dim is not None:
count = self._obj[extra_dim].size
with rasterio.open(
raster_path,
"w",
driver=driver,
height=int(height),
width=int(width),
count=len(self._obj.dims),
count=count,
dtype=dtype,
crs=self.crs,
transform=self.transform(recalc=True),
Expand Down
15 changes: 15 additions & 0 deletions test/integration/test_integration_rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,3 +688,18 @@ def test_to_raster(tmpdir):
assert_array_equal(rds.transform, xds.rio.transform())
assert_array_equal(rds.nodata, xds.rio.nodata)
assert_array_equal(rds.read(1), xds.values)


def test_to_raster_3d(tmpdir):
tmp_raster = tmpdir.join("planet_3d_raster.tiff")
with xarray.open_dataset(
os.path.join(TEST_INPUT_DATA_DIR, "PLANET_SCOPE_3D.nc"), autoclose=True
) as mda:
mda.green.rio.to_raster(str(tmp_raster))
xds = mda.green.copy()

with rasterio.open(str(tmp_raster)) as rds:
assert rds.crs == xds.rio.crs
assert_array_equal(rds.transform, xds.rio.transform())
assert_array_equal(rds.nodata, xds.rio.nodata)
assert_array_equal(rds.read(), xds.values)

0 comments on commit fc74a2f

Please sign in to comment.