Skip to content

Commit

Permalink
BUG: Fix 'rio.set_spatial_dims' so information is preserved with 'rio…
Browse files Browse the repository at this point in the history
…' accesors (#95)
  • Loading branch information
snowman2 committed Mar 6, 2020
1 parent 6c26aac commit fabed24
Show file tree
Hide file tree
Showing 5 changed files with 269 additions and 44 deletions.
4 changes: 4 additions & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
History
=======

0.0.23
------
- BUG: Fix 'rio.set_spatial_dims' so information saved with 'rio' accesors (issue #94)

0.0.22
-------
- ENH: Use pyproj.CRS internally to manage GDAL 2/3 transition (issue #92)
Expand Down
139 changes: 95 additions & 44 deletions rioxarray/rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ def add_spatial_ref(in_ds, dst_crs, grid_map_name):

def _add_attrs_proj(new_data_array, src_data_array):
"""Make sure attributes and projection correct"""
# make sure dimension information is preserved
if new_data_array.rio._x_dim is None:
new_data_array.rio._x_dim = src_data_array.rio.x_dim
if new_data_array.rio._y_dim is None:
new_data_array.rio._y_dim = src_data_array.rio.y_dim

# make sure attributes preserved
new_attrs = _generate_attrs(
src_data_array, new_data_array.rio.transform(recalc=True), None
Expand Down Expand Up @@ -319,7 +325,14 @@ def _get_obj(self, inplace):
"""
if inplace:
return self._obj
return self._obj.copy(deep=True)
obj_copy = self._obj.copy(deep=True)
# preserve attribute information
obj_copy.rio._x_dim = self._x_dim
obj_copy.rio._y_dim = self._y_dim
obj_copy.rio._width = self._width
obj_copy.rio._height = self._height
obj_copy.rio._crs = self._crs
return obj_copy

def set_crs(self, input_crs, inplace=True):
"""
Expand Down Expand Up @@ -399,9 +412,12 @@ def write_crs(
):
data_obj[var].rio.update_attrs(
dict(grid_mapping=grid_mapping_name), inplace=True
).rio.set_spatial_dims(
x_dim=self.x_dim, y_dim=self.y_dim, inplace=True
)
data_obj.rio.update_attrs(dict(grid_mapping=grid_mapping_name), inplace=True)
return data_obj
return data_obj.rio.update_attrs(
dict(grid_mapping=grid_mapping_name), inplace=True
)

def set_attrs(self, new_attrs, inplace=False):
"""
Expand All @@ -423,7 +439,8 @@ def set_attrs(self, new_attrs, inplace=False):
data_obj = self._get_obj(inplace=inplace)
# set the attributes
data_obj.attrs = new_attrs
# reset rioxarray properties
# reset rioxarray properties depending
# on attributes to be generated
data_obj.rio._nodata = None
data_obj.rio._crs = None
return data_obj
Expand Down Expand Up @@ -522,6 +539,26 @@ def shape(self):
"""tuple: Returns the shape (width, height)"""
return (self.width, self.height)

def isel_window(self, window):
"""
Use a rasterio.window.Window to select a subset of the data.
Parameters
----------
window: :class:`rasterio.window.Window`
The window of the dataset to read.
Returns
-------
:obj:`xarray.Dataset` | :obj:`xarray.DataArray`: The data in the window.
"""
(row_start, row_stop), (col_start, col_stop) = window.toranges()
row_slice = slice(int(math.floor(row_start)), int(math.ceil(row_stop)))
col_slice = slice(int(math.floor(col_start)), int(math.ceil(col_stop)))
return self._obj.isel(
{self.x_dim: row_slice, self.y_dim: col_slice}
).rio.set_spatial_dims(x_dim=self.x_dim, y_dim=self.y_dim, inplace=True)


@xarray.register_dataarray_accessor("rio")
class RasterArray(XRasterBase):
Expand Down Expand Up @@ -930,7 +967,9 @@ def slice_xy(self, minx, miny, maxx, maxy):
else:
x_slice = slice(minx, maxx)

return self._obj.sel({self.x_dim: x_slice, self.y_dim: y_slice})
return self._obj.sel(
{self.x_dim: x_slice, self.y_dim: y_slice}
).rio.set_spatial_dims(x_dim=self.x_dim, y_dim=self.y_dim, inplace=True)

def clip_box(self, minx, miny, maxx, maxy, auto_expand=False, auto_expand_limit=3):
"""Clip the :class:`xarray.DataArray` by a bounding box.
Expand Down Expand Up @@ -1079,24 +1118,6 @@ def clip(self, geometries, crs, all_touched=False, drop=True, invert=False):

return cropped_ds

def isel_window(self, window):
"""
Use a rasterio.window.Window to select a subset of the data.
Parameters
----------
window: :class:`rasterio.window.Window`
The window of the dataset to read.
Returns
-------
:obj:`xarray.Dataset` | :obj:`xarray.DataArray`: The data in the window.
"""
(row_start, row_stop), (col_start, col_stop) = window.toranges()
row_slice = slice(int(math.floor(row_start)), int(math.ceil(row_stop)))
col_slice = slice(int(math.floor(col_start)), int(math.ceil(col_stop)))
return self._obj.isel({self.x_dim: row_slice, self.y_dim: col_slice})

def _interpolate_na(self, src_data, method="nearest"):
"""
This method uses scipy.interpolate.griddata to interpolate missing data.
Expand Down Expand Up @@ -1340,11 +1361,15 @@ def reproject(
"""
resampled_dataset = xarray.Dataset(attrs=self._obj.attrs)
for var in self.vars:
resampled_dataset[var] = self._obj[var].rio.reproject(
dst_crs,
resolution=resolution,
dst_affine_width_height=dst_affine_width_height,
resampling=resampling,
resampled_dataset[var] = (
self._obj[var]
.rio.set_spatial_dims(x_dim=self.x_dim, y_dim=self.y_dim, inplace=True)
.rio.reproject(
dst_crs,
resolution=resolution,
dst_affine_width_height=dst_affine_width_height,
resampling=resampling,
)
)
return resampled_dataset

Expand Down Expand Up @@ -1375,10 +1400,14 @@ def reproject_match(self, match_data_array, resampling=Resampling.nearest):
"""
resampled_dataset = xarray.Dataset(attrs=self._obj.attrs)
for var in self.vars:
resampled_dataset[var] = self._obj[var].rio.reproject_match(
match_data_array, resampling=resampling
resampled_dataset[var] = (
self._obj[var]
.rio.set_spatial_dims(x_dim=self.x_dim, y_dim=self.y_dim, inplace=True)
.rio.reproject_match(match_data_array, resampling=resampling)
)
return resampled_dataset
return resampled_dataset.rio.set_spatial_dims(
x_dim=self.x_dim, y_dim=self.y_dim, inplace=True
)

def clip_box(self, minx, miny, maxx, maxy, auto_expand=False, auto_expand_limit=3):
"""Clip the :class:`xarray.Dataset` by a bounding box.
Expand Down Expand Up @@ -1409,15 +1438,21 @@ def clip_box(self, minx, miny, maxx, maxy, auto_expand=False, auto_expand_limit=
"""
clipped_dataset = xarray.Dataset(attrs=self._obj.attrs)
for var in self.vars:
clipped_dataset[var] = self._obj[var].rio.clip_box(
minx,
miny,
maxx,
maxy,
auto_expand=auto_expand,
auto_expand_limit=auto_expand_limit,
clipped_dataset[var] = (
self._obj[var]
.rio.set_spatial_dims(x_dim=self.x_dim, y_dim=self.y_dim, inplace=True)
.rio.clip_box(
minx,
miny,
maxx,
maxy,
auto_expand=auto_expand,
auto_expand_limit=auto_expand_limit,
)
)
return clipped_dataset
return clipped_dataset.rio.set_spatial_dims(
x_dim=self.x_dim, y_dim=self.y_dim, inplace=True
)

def clip(self, geometries, crs, all_touched=False, drop=True, invert=False):
"""
Expand Down Expand Up @@ -1467,10 +1502,20 @@ def clip(self, geometries, crs, all_touched=False, drop=True, invert=False):
"""
clipped_dataset = xarray.Dataset(attrs=self._obj.attrs)
for var in self.vars:
clipped_dataset[var] = self._obj[var].rio.clip(
geometries, crs=crs, all_touched=all_touched, drop=drop, invert=invert
clipped_dataset[var] = (
self._obj[var]
.rio.set_spatial_dims(x_dim=self.x_dim, y_dim=self.y_dim, inplace=True)
.rio.clip(
geometries,
crs=crs,
all_touched=all_touched,
drop=drop,
invert=invert,
)
)
return clipped_dataset
return clipped_dataset.rio.set_spatial_dims(
x_dim=self.x_dim, y_dim=self.y_dim, inplace=True
)

def interpolate_na(self, method="nearest"):
"""
Expand All @@ -1488,8 +1533,14 @@ def interpolate_na(self, method="nearest"):
"""
interpolated_dataset = xarray.Dataset(attrs=self._obj.attrs)
for var in self.vars:
interpolated_dataset[var] = self._obj[var].rio.interpolate_na(method=method)
return interpolated_dataset
interpolated_dataset[var] = (
self._obj[var]
.rio.set_spatial_dims(x_dim=self.x_dim, y_dim=self.y_dim, inplace=True)
.rio.interpolate_na(method=method)
)
return interpolated_dataset.rio.set_spatial_dims(
x_dim=self.x_dim, y_dim=self.y_dim, inplace=True
)

def to_raster(
self,
Expand Down
143 changes: 143 additions & 0 deletions test/integration/test_integration_rioxarray.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
from functools import partial

Expand Down Expand Up @@ -1329,3 +1330,145 @@ def test_write_pyproj_crs_dataset():
test_ds = test_ds.rio.write_crs(pCRS(4326))
assert test_ds.attrs["grid_mapping"] == "spatial_ref"
assert test_ds.rio.crs.to_epsg() == 4326


def test_nonstandard_dims_clip__dataset():
with open(os.path.join(TEST_INPUT_DATA_DIR, "nonstandard_dim_geom.json")) as ndj:
geom = json.load(ndj)
with xarray.open_dataset(
os.path.join(TEST_INPUT_DATA_DIR, "nonstandard_dim.nc")
) as xds:
clipped = (
xds.rio.set_spatial_dims(x_dim="lon", y_dim="lat")
.rio.write_crs("EPSG:4326")
.rio.clip([geom], "EPSG:4326")
)
assert clipped.rio.width == 6
assert clipped.rio.height == 5


def test_nonstandard_dims_clip__array():
with open(os.path.join(TEST_INPUT_DATA_DIR, "nonstandard_dim_geom.json")) as ndj:
geom = json.load(ndj)
with xarray.open_dataset(
os.path.join(TEST_INPUT_DATA_DIR, "nonstandard_dim.nc")
) as xds:
clipped = (
xds.analysed_sst.rio.set_spatial_dims(x_dim="lon", y_dim="lat")
.rio.write_crs("EPSG:4326")
.rio.clip([geom], "EPSG:4326")
)
assert clipped.rio.width == 6
assert clipped.rio.height == 5


def test_nonstandard_dims_clip_box__dataset():
with xarray.open_dataset(
os.path.join(TEST_INPUT_DATA_DIR, "nonstandard_dim.nc")
) as xds:
clipped = (
xds.rio.set_spatial_dims(x_dim="lon", y_dim="lat")
.rio.write_crs("EPSG:4326")
.rio.clip_box(
-70.51367964678269,
-23.780199727400767,
-70.44589567737998,
-23.71896017814794,
)
)
assert clipped.rio.width == 7
assert clipped.rio.height == 7


def test_nonstandard_dims_clip_box_array():
with xarray.open_dataset(
os.path.join(TEST_INPUT_DATA_DIR, "nonstandard_dim.nc")
) as xds:
clipped = (
xds.analysed_sst.rio.set_spatial_dims(x_dim="lon", y_dim="lat")
.rio.write_crs("EPSG:4326")
.rio.clip_box(
-70.51367964678269,
-23.780199727400767,
-70.44589567737998,
-23.71896017814794,
)
)
assert clipped.rio.width == 7
assert clipped.rio.height == 7


def test_nonstandard_dims_reproject__dataset():
with xarray.open_dataset(
os.path.join(TEST_INPUT_DATA_DIR, "nonstandard_dim.nc")
) as xds:
xds = xds.rio.set_spatial_dims(x_dim="lon", y_dim="lat").rio.write_crs(
"EPSG:4326"
)
reprojected = xds.rio.reproject("epsg:3857")
assert reprojected.rio.width == 11
assert reprojected.rio.height == 11
assert reprojected.rio.crs.to_epsg() == 3857


def test_nonstandard_dims_reproject__array():
with xarray.open_dataset(
os.path.join(TEST_INPUT_DATA_DIR, "nonstandard_dim.nc")
) as xds:
reprojected = (
xds.analysed_sst.rio.set_spatial_dims(x_dim="lon", y_dim="lat")
.rio.write_crs("EPSG:4326")
.rio.reproject("epsg:3857")
)
assert reprojected.rio.width == 11
assert reprojected.rio.height == 11
assert reprojected.rio.crs.to_epsg() == 3857


def test_nonstandard_dims_interpolate_na__dataset():
with xarray.open_dataset(
os.path.join(TEST_INPUT_DATA_DIR, "nonstandard_dim.nc")
) as xds:
reprojected = (
xds.rio.set_spatial_dims(x_dim="lon", y_dim="lat")
.rio.write_crs("EPSG:4326")
.rio.interpolate_na()
)
assert reprojected.rio.width == 11
assert reprojected.rio.height == 11


def test_nonstandard_dims_interpolate_na__array():
with xarray.open_dataset(
os.path.join(TEST_INPUT_DATA_DIR, "nonstandard_dim.nc")
) as xds:
reprojected = (
xds.analysed_sst.rio.set_spatial_dims(x_dim="lon", y_dim="lat")
.rio.write_crs("EPSG:4326")
.rio.interpolate_na()
)
assert reprojected.rio.width == 11
assert reprojected.rio.height == 11


def test_nonstandard_dims_write_nodata__array():
with xarray.open_dataset(
os.path.join(TEST_INPUT_DATA_DIR, "nonstandard_dim.nc")
) as xds:
reprojected = xds.analysed_sst.rio.set_spatial_dims(
x_dim="lon", y_dim="lat"
).rio.write_nodata(-999)
assert reprojected.rio.width == 11
assert reprojected.rio.height == 11
assert reprojected.rio.nodata == -999


def test_nonstandard_dims_isel_window():
with xarray.open_dataset(
os.path.join(TEST_INPUT_DATA_DIR, "nonstandard_dim.nc")
) as xds:
reprojected = xds.rio.set_spatial_dims(
x_dim="lon", y_dim="lat"
).rio.isel_window(Window.from_slices(slice(5), slice(5)))
assert reprojected.rio.width == 5
assert reprojected.rio.height == 5
Binary file added test/test_data/input/nonstandard_dim.nc
Binary file not shown.
Loading

0 comments on commit fabed24

Please sign in to comment.