Skip to content

Commit

Permalink
BUG:dataset: Fix writing tags for bands & prevent overwriting long_na…
Browse files Browse the repository at this point in the history
…me attribute
  • Loading branch information
snowman2 committed Dec 8, 2022
1 parent b186d0c commit d7a1cc7
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 43 deletions.
2 changes: 2 additions & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ History

Latest
------
- BUG:dataset: Fix writing tags for bands (issue #615)
- BUG:dataset: prevent overwriting long_name attribute (pull #616)

0.13.1
------
Expand Down
9 changes: 7 additions & 2 deletions rioxarray/raster_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,18 +497,23 @@ def to_raster(
"""
variable_dim = f"band_{uuid4()}"
data_array = self._obj.to_array(dim=variable_dim)
# write data array names to raster
data_array.attrs["long_name"] = data_array[variable_dim].values.tolist()
# ensure raster metadata preserved
scales = []
offsets = []
nodatavals = []
band_tags = []
long_name = []
for data_var in data_array[variable_dim].values:
scales.append(self._obj[data_var].attrs.get("scale_factor", 1.0))
offsets.append(self._obj[data_var].attrs.get("add_offset", 0.0))
long_name.append(self._obj[data_var].attrs.get("long_name", data_var))
nodatavals.append(self._obj[data_var].rio.nodata)
band_tags.append(self._obj[data_var].attrs.copy())
data_array.attrs["scales"] = scales
data_array.attrs["offsets"] = offsets
data_array.attrs["band_tags"] = band_tags
data_array.attrs["long_name"] = long_name

nodata = nodatavals[0]
if (
all(nodataval == nodata for nodataval in nodatavals)
Expand Down
67 changes: 42 additions & 25 deletions rioxarray/raster_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,32 +36,10 @@ def is_dask_collection(_) -> bool: # type: ignore
# Note: transform & crs are removed in write_transform/write_crs


def _write_metatata_to_raster(raster_handle, xarray_dataset, tags):
def _write_tags(raster_handle, tags):
"""
Write the metadata stored in the xarray object to raster metadata
Write tags to raster dataset
"""
tags = (
xarray_dataset.attrs.copy()
if tags is None
else {**xarray_dataset.attrs, **tags}
)

# write scales and offsets
try:
raster_handle.scales = tags["scales"]
except KeyError:
scale_factor = tags.get(
"scale_factor", xarray_dataset.encoding.get("scale_factor")
)
if scale_factor is not None:
raster_handle.scales = (scale_factor,) * raster_handle.count
try:
raster_handle.offsets = tags["offsets"]
except KeyError:
add_offset = tags.get("add_offset", xarray_dataset.encoding.get("add_offset"))
if add_offset is not None:
raster_handle.offsets = (add_offset,) * raster_handle.count

# filter out attributes that should be written in a different location
skip_tags = (
UNWANTED_RIO_ATTRS
Expand All @@ -80,10 +58,19 @@ def _write_metatata_to_raster(raster_handle, xarray_dataset, tags):
# in this case, it will be stored in the raster description
if not isinstance(tags.get("long_name"), str):
skip_tags += ("long_name",)
band_tags = tags.pop("band_tags", [])
tags = {key: value for key, value in tags.items() if key not in skip_tags}
raster_handle.update_tags(**tags)

# write band name information
if isinstance(band_tags, list):
for iii, band_tag in enumerate(band_tags):
raster_handle.update_tags(iii + 1, **band_tag)


def _write_band_description(raster_handle, xarray_dataset):
"""
Write band descriptions using the long name
"""
long_name = xarray_dataset.attrs.get("long_name")
if isinstance(long_name, (tuple, list)):
if len(long_name) != raster_handle.count:
Expand All @@ -100,6 +87,36 @@ def _write_metatata_to_raster(raster_handle, xarray_dataset, tags):
raster_handle.set_band_description(iii + 1, band_description)


def _write_metatata_to_raster(raster_handle, xarray_dataset, tags):
"""
Write the metadata stored in the xarray object to raster metadata
"""
tags = (
xarray_dataset.attrs.copy()
if tags is None
else {**xarray_dataset.attrs, **tags}
)

# write scales and offsets
try:
raster_handle.scales = tags["scales"]
except KeyError:
scale_factor = tags.get(
"scale_factor", xarray_dataset.encoding.get("scale_factor")
)
if scale_factor is not None:
raster_handle.scales = (scale_factor,) * raster_handle.count
try:
raster_handle.offsets = tags["offsets"]
except KeyError:
add_offset = tags.get("add_offset", xarray_dataset.encoding.get("add_offset"))
if add_offset is not None:
raster_handle.offsets = (add_offset,) * raster_handle.count

_write_tags(raster_handle=raster_handle, tags=tags)
_write_band_description(raster_handle=raster_handle, xarray_dataset=xarray_dataset)


def _ensure_nodata_dtype(original_nodata, new_dtype):
"""
Convert the nodata to the new datatype and raise warning
Expand Down
46 changes: 31 additions & 15 deletions test/integration/test_integration__io.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,25 +484,41 @@ def test_utm():
assert "y" not in rioda.coords


def test_band_as_variable(open_rasterio):
def test_band_as_variable(open_rasterio, tmp_path):
test_raster = tmp_path / "test.tif"

with create_tmp_geotiff() as (tmp_file, expected):
with open_rasterio(
tmp_file, band_as_variable=True, mask_and_scale=False
) as riods:
for band in (1, 2, 3):
band_name = f"band_{band}"
assert_allclose(riods[band_name], expected.sel(band=band).drop("band"))
assert riods[band_name].attrs["BAND"] == band
assert riods[band_name].attrs["scale_factor"] == 1.0
assert riods[band_name].attrs["add_offset"] == 0.0
assert riods[band_name].attrs["long_name"] == f"d{band}"
assert riods[band_name].attrs["units"] == f"u{band}"
assert riods[band_name].rio.crs == expected.rio.crs
assert_array_equal(
riods[band_name].rio.resolution(), expected.rio.resolution()
)
assert isinstance(riods[band_name].rio._cached_transform(), Affine)
assert riods[band_name].rio.nodata is None

def _check_raster(raster_ds):
for band in (1, 2, 3):
band_name = f"band_{band}"
assert_allclose(
raster_ds[band_name], expected.sel(band=band).drop("band")
)
assert raster_ds[band_name].attrs["BAND"] == band
assert raster_ds[band_name].attrs["scale_factor"] == 1.0
assert raster_ds[band_name].attrs["add_offset"] == 0.0
assert raster_ds[band_name].attrs["long_name"] == f"d{band}"
assert raster_ds[band_name].attrs["units"] == f"u{band}"
assert raster_ds[band_name].rio.crs == expected.rio.crs
assert_array_equal(
raster_ds[band_name].rio.resolution(), expected.rio.resolution()
)
assert isinstance(
raster_ds[band_name].rio._cached_transform(), Affine
)
assert raster_ds[band_name].rio.nodata is None

_check_raster(riods)
# test roundtrip
riods.rio.to_raster(test_raster)
with open_rasterio(
test_raster, band_as_variable=True, mask_and_scale=False
) as riods_round:
_check_raster(riods_round)


def test_platecarree():
Expand Down
2 changes: 1 addition & 1 deletion test/integration/test_integration_rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,7 +1740,7 @@ def test_to_raster__dataset__mask_and_scale(chunks, tmpdir):
with rioxarray.open_rasterio(str(output_raster)) as rdscompare:
assert rdscompare.scale_factor == 0.1
assert rdscompare.add_offset == 220.0
assert rdscompare.long_name == "air_temperature"
assert rdscompare.long_name == "tmmx"
assert rdscompare.rio.crs == rds.rio.crs
assert rdscompare.rio.nodata == rds.air_temperature.rio.nodata

Expand Down

0 comments on commit d7a1cc7

Please sign in to comment.