From d7a1cc7af38600733dd213c76c007325319605de Mon Sep 17 00:00:00 2001 From: snowman2 Date: Thu, 8 Dec 2022 15:50:58 -0600 Subject: [PATCH] BUG:dataset: Fix writing tags for bands & prevent overwriting long_name attribute --- docs/history.rst | 2 + rioxarray/raster_dataset.py | 9 ++- rioxarray/raster_writer.py | 67 ++++++++++++------- test/integration/test_integration__io.py | 46 ++++++++----- .../integration/test_integration_rioxarray.py | 2 +- 5 files changed, 83 insertions(+), 43 deletions(-) diff --git a/docs/history.rst b/docs/history.rst index c5127006..ed7971d3 100644 --- a/docs/history.rst +++ b/docs/history.rst @@ -3,6 +3,8 @@ History Latest ------ +- BUG:dataset: Fix writing tags for bands (issue #615) +- BUG:dataset: prevent overwriting long_name attribute (pull #616) 0.13.1 ------ diff --git a/rioxarray/raster_dataset.py b/rioxarray/raster_dataset.py index d0a3275f..ff262a81 100644 --- a/rioxarray/raster_dataset.py +++ b/rioxarray/raster_dataset.py @@ -497,18 +497,23 @@ def to_raster( """ variable_dim = f"band_{uuid4()}" data_array = self._obj.to_array(dim=variable_dim) - # write data array names to raster - data_array.attrs["long_name"] = data_array[variable_dim].values.tolist() # ensure raster metadata preserved scales = [] offsets = [] nodatavals = [] + band_tags = [] + long_name = [] for data_var in data_array[variable_dim].values: scales.append(self._obj[data_var].attrs.get("scale_factor", 1.0)) offsets.append(self._obj[data_var].attrs.get("add_offset", 0.0)) + long_name.append(self._obj[data_var].attrs.get("long_name", data_var)) nodatavals.append(self._obj[data_var].rio.nodata) + band_tags.append(self._obj[data_var].attrs.copy()) data_array.attrs["scales"] = scales data_array.attrs["offsets"] = offsets + data_array.attrs["band_tags"] = band_tags + data_array.attrs["long_name"] = long_name + nodata = nodatavals[0] if ( all(nodataval == nodata for nodataval in nodatavals) diff --git a/rioxarray/raster_writer.py b/rioxarray/raster_writer.py index 7d49d35c..105916a2 100644 --- a/rioxarray/raster_writer.py +++ b/rioxarray/raster_writer.py @@ -36,32 +36,10 @@ def is_dask_collection(_) -> bool: # type: ignore # Note: transform & crs are removed in write_transform/write_crs -def _write_metatata_to_raster(raster_handle, xarray_dataset, tags): +def _write_tags(raster_handle, tags): """ - Write the metadata stored in the xarray object to raster metadata + Write tags to raster dataset """ - tags = ( - xarray_dataset.attrs.copy() - if tags is None - else {**xarray_dataset.attrs, **tags} - ) - - # write scales and offsets - try: - raster_handle.scales = tags["scales"] - except KeyError: - scale_factor = tags.get( - "scale_factor", xarray_dataset.encoding.get("scale_factor") - ) - if scale_factor is not None: - raster_handle.scales = (scale_factor,) * raster_handle.count - try: - raster_handle.offsets = tags["offsets"] - except KeyError: - add_offset = tags.get("add_offset", xarray_dataset.encoding.get("add_offset")) - if add_offset is not None: - raster_handle.offsets = (add_offset,) * raster_handle.count - # filter out attributes that should be written in a different location skip_tags = ( UNWANTED_RIO_ATTRS @@ -80,10 +58,19 @@ def _write_metatata_to_raster(raster_handle, xarray_dataset, tags): # in this case, it will be stored in the raster description if not isinstance(tags.get("long_name"), str): skip_tags += ("long_name",) + band_tags = tags.pop("band_tags", []) tags = {key: value for key, value in tags.items() if key not in skip_tags} raster_handle.update_tags(**tags) - # write band name information + if isinstance(band_tags, list): + for iii, band_tag in enumerate(band_tags): + raster_handle.update_tags(iii + 1, **band_tag) + + +def _write_band_description(raster_handle, xarray_dataset): + """ + Write band descriptions using the long name + """ long_name = xarray_dataset.attrs.get("long_name") if isinstance(long_name, (tuple, list)): if len(long_name) != raster_handle.count: @@ -100,6 +87,36 @@ def _write_metatata_to_raster(raster_handle, xarray_dataset, tags): raster_handle.set_band_description(iii + 1, band_description) +def _write_metatata_to_raster(raster_handle, xarray_dataset, tags): + """ + Write the metadata stored in the xarray object to raster metadata + """ + tags = ( + xarray_dataset.attrs.copy() + if tags is None + else {**xarray_dataset.attrs, **tags} + ) + + # write scales and offsets + try: + raster_handle.scales = tags["scales"] + except KeyError: + scale_factor = tags.get( + "scale_factor", xarray_dataset.encoding.get("scale_factor") + ) + if scale_factor is not None: + raster_handle.scales = (scale_factor,) * raster_handle.count + try: + raster_handle.offsets = tags["offsets"] + except KeyError: + add_offset = tags.get("add_offset", xarray_dataset.encoding.get("add_offset")) + if add_offset is not None: + raster_handle.offsets = (add_offset,) * raster_handle.count + + _write_tags(raster_handle=raster_handle, tags=tags) + _write_band_description(raster_handle=raster_handle, xarray_dataset=xarray_dataset) + + def _ensure_nodata_dtype(original_nodata, new_dtype): """ Convert the nodata to the new datatype and raise warning diff --git a/test/integration/test_integration__io.py b/test/integration/test_integration__io.py index 0353c279..e5375cd0 100644 --- a/test/integration/test_integration__io.py +++ b/test/integration/test_integration__io.py @@ -484,25 +484,41 @@ def test_utm(): assert "y" not in rioda.coords -def test_band_as_variable(open_rasterio): +def test_band_as_variable(open_rasterio, tmp_path): + test_raster = tmp_path / "test.tif" + with create_tmp_geotiff() as (tmp_file, expected): with open_rasterio( tmp_file, band_as_variable=True, mask_and_scale=False ) as riods: - for band in (1, 2, 3): - band_name = f"band_{band}" - assert_allclose(riods[band_name], expected.sel(band=band).drop("band")) - assert riods[band_name].attrs["BAND"] == band - assert riods[band_name].attrs["scale_factor"] == 1.0 - assert riods[band_name].attrs["add_offset"] == 0.0 - assert riods[band_name].attrs["long_name"] == f"d{band}" - assert riods[band_name].attrs["units"] == f"u{band}" - assert riods[band_name].rio.crs == expected.rio.crs - assert_array_equal( - riods[band_name].rio.resolution(), expected.rio.resolution() - ) - assert isinstance(riods[band_name].rio._cached_transform(), Affine) - assert riods[band_name].rio.nodata is None + + def _check_raster(raster_ds): + for band in (1, 2, 3): + band_name = f"band_{band}" + assert_allclose( + raster_ds[band_name], expected.sel(band=band).drop("band") + ) + assert raster_ds[band_name].attrs["BAND"] == band + assert raster_ds[band_name].attrs["scale_factor"] == 1.0 + assert raster_ds[band_name].attrs["add_offset"] == 0.0 + assert raster_ds[band_name].attrs["long_name"] == f"d{band}" + assert raster_ds[band_name].attrs["units"] == f"u{band}" + assert raster_ds[band_name].rio.crs == expected.rio.crs + assert_array_equal( + raster_ds[band_name].rio.resolution(), expected.rio.resolution() + ) + assert isinstance( + raster_ds[band_name].rio._cached_transform(), Affine + ) + assert raster_ds[band_name].rio.nodata is None + + _check_raster(riods) + # test roundtrip + riods.rio.to_raster(test_raster) + with open_rasterio( + test_raster, band_as_variable=True, mask_and_scale=False + ) as riods_round: + _check_raster(riods_round) def test_platecarree(): diff --git a/test/integration/test_integration_rioxarray.py b/test/integration/test_integration_rioxarray.py index 983c0e8e..2e100186 100644 --- a/test/integration/test_integration_rioxarray.py +++ b/test/integration/test_integration_rioxarray.py @@ -1740,7 +1740,7 @@ def test_to_raster__dataset__mask_and_scale(chunks, tmpdir): with rioxarray.open_rasterio(str(output_raster)) as rdscompare: assert rdscompare.scale_factor == 0.1 assert rdscompare.add_offset == 220.0 - assert rdscompare.long_name == "air_temperature" + assert rdscompare.long_name == "tmmx" assert rdscompare.rio.crs == rds.rio.crs assert rdscompare.rio.nodata == rds.air_temperature.rio.nodata