Skip to content

Commit

Permalink
ENH: Added band_as_variable option to open_rasterio (#600)
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Nov 2, 2022
1 parent 697cb9e commit 9643b46
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ enable=c-extension-no-member


[FORMAT]
max-module-lines=1200
max-module-lines=1250

[DESIGN]
max-locals=20
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Latest

0.12.4
------
- ENH: Added band_as_variable option to open_rasterio (issue #296)
- BUG: Pass warp_extras dictionary to raster.vrt.WarpedVRT (issue #598)

0.12.3
Expand Down
177 changes: 172 additions & 5 deletions rioxarray/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import contextlib
import functools
import importlib.metadata
import os
import re
Expand Down Expand Up @@ -38,7 +39,91 @@
RASTERIO_LOCK = SerializableLock()
NO_LOCK = contextlib.nullcontext()

RasterioReader = Union[rasterio.io.DatasetReader, rasterio.vrt.WarpedVRT]

class SingleBandDatasetReader:
"""
Hack to have a DatasetReader behave like it only has one band
"""

def __init__(self, riods, bidx) -> None:
self._riods = riods
self._bidx = bidx

def __getattr__(self, __name: str) -> Any:
return getattr(self._riods, __name)

@property
def count(self):
"""
int: band count
"""
return 1

@property
def nodata(self):
"""
Nodata value for the band
"""
return self._riods.nodatavals[self._bidx]

@property
def offsets(self):
"""
Offset value for the band
"""
return [self._riods.offsets[self._bidx]]

@property
def scales(self):
"""
Scale value for the band
"""
return [self._riods.scales[self._bidx]]

@property
def units(self):
"""
Unit for the band
"""
return [self._riods.units[self._bidx]]

@property
def descriptions(self):
"""
Description for the band
"""
return [self._riods.descriptions[self._bidx]]

@property
def dtypes(self):
"""
dtype for the band
"""
return [self._riods.dtypes[self._bidx]]

@property
def indexes(self):
"""
indexes for the band
"""
return [self._riods.indexes[self._bidx]]

def read(self, indexes=None, **kwargs): # pylint: disable=unused-argument
"""
read data for the band
"""
return self._riods.read(indexes=self._bidx + 1, **kwargs)

def tags(self, bidx=None, **kwargs): # pylint: disable=unused-argument
"""
read tags for the band
"""
return self._riods.tags(bidx=self._bidx + 1, **kwargs)


RasterioReader = Union[
rasterio.io.DatasetReader, rasterio.vrt.WarpedVRT, SingleBandDatasetReader
]


try:
Expand Down Expand Up @@ -711,6 +796,49 @@ def _load_subdatasets(
return dataset


def _load_bands_as_variables(
riods: RasterioReader,
parse_coordinates: bool,
chunks: Optional[Union[int, Tuple, Dict]],
cache: Optional[bool],
lock: Any,
masked: bool,
mask_and_scale: bool,
decode_times: bool,
decode_timedelta: Optional[bool],
**open_kwargs,
) -> Union[Dataset, List[Dataset]]:
"""
Load in rasterio bands as variables
"""
global_tags = _parse_tags(riods.tags())
data_vars = {}
for band in riods.indexes:
band_riods = SingleBandDatasetReader(
riods=riods,
bidx=band - 1,
)
band_name = f"band_{band}"
data_vars[band_name] = (
open_rasterio( # type: ignore
band_riods,
parse_coordinates=band == 1 and parse_coordinates,
chunks=chunks,
cache=cache,
lock=lock,
masked=masked,
mask_and_scale=mask_and_scale,
default_name=band_name,
decode_times=decode_times,
decode_timedelta=decode_timedelta,
**open_kwargs,
)
.squeeze() # type: ignore
.drop("band") # type: ignore
)
return Dataset(data_vars, attrs=global_tags)


def _prepare_dask(
result: DataArray,
riods: RasterioReader,
Expand Down Expand Up @@ -785,9 +913,23 @@ def _handle_encoding(
)


def _single_band_open(*args, bidx=0, **kwargs):
"""
Open file as if it only has a single band
"""
return SingleBandDatasetReader(
riods=rasterio.open(*args, **kwargs),
bidx=bidx,
)


def open_rasterio(
filename: Union[
str, os.PathLike, rasterio.io.DatasetReader, rasterio.vrt.WarpedVRT
str,
os.PathLike,
rasterio.io.DatasetReader,
rasterio.vrt.WarpedVRT,
SingleBandDatasetReader,
],
parse_coordinates: Optional[bool] = None,
chunks: Optional[Union[int, Tuple, Dict]] = None,
Expand All @@ -800,6 +942,7 @@ def open_rasterio(
default_name: Optional[str] = None,
decode_times: bool = True,
decode_timedelta: Optional[bool] = None,
band_as_variable: bool = False,
**open_kwargs,
) -> Union[Dataset, DataArray, List[Dataset]]:
# pylint: disable=too-many-statements,too-many-locals,too-many-branches
Expand All @@ -812,6 +955,8 @@ def open_rasterio(
<http://web.archive.org/web/20160326194152/http://remotesensing.org/geotiff/spec/geotiff2.5.html#2.5.2>`_
for more information).
.. versionadded:: 0.13 band_as_variable
Parameters
----------
filename: str, rasterio.io.DatasetReader, or rasterio.vrt.WarpedVRT
Expand Down Expand Up @@ -866,6 +1011,8 @@ def open_rasterio(
{“days”, “hours”, “minutes”, “seconds”, “milliseconds”, “microseconds”}
into timedelta objects. If False, leave them encoded as numbers.
If None (default), assume the same value of decode_time.
band_as_variable: bool, default=False
If True, will load bands in a raster to separate variables.
**open_kwargs: kwargs, optional
Optional keyword arguments to pass into :func:`rasterio.open`.
Expand All @@ -877,7 +1024,13 @@ def open_rasterio(
parse_coordinates = True if parse_coordinates is None else parse_coordinates
masked = masked or mask_and_scale
vrt_params = None
if isinstance(filename, rasterio.io.DatasetReader):
file_opener = rasterio.open
if isinstance(filename, SingleBandDatasetReader):
file_opener = functools.partial(
_single_band_open,
bidx=filename._bidx,
)
if isinstance(filename, (rasterio.io.DatasetReader, SingleBandDatasetReader)):
filename = filename.name
elif isinstance(filename, rasterio.vrt.WarpedVRT):
vrt = filename
Expand Down Expand Up @@ -909,13 +1062,27 @@ def open_rasterio(
with warnings.catch_warnings(record=True) as rio_warnings:
if lock is not NO_LOCK and isinstance(filename, (str, os.PathLike)):
manager: FileManager = CachingFileManager(
rasterio.open, filename, lock=lock, mode="r", kwargs=open_kwargs
file_opener, filename, lock=lock, mode="r", kwargs=open_kwargs
)
else:
manager = URIManager(rasterio.open, filename, mode="r", kwargs=open_kwargs)
manager = URIManager(file_opener, filename, mode="r", kwargs=open_kwargs)
riods = manager.acquire()
captured_warnings = rio_warnings.copy()

if band_as_variable:
return _load_bands_as_variables(
riods=riods,
parse_coordinates=parse_coordinates,
chunks=chunks,
cache=cache,
lock=lock,
masked=masked,
mask_and_scale=mask_and_scale,
decode_times=decode_times,
decode_timedelta=decode_timedelta,
**open_kwargs,
)

# raise the NotGeoreferencedWarning if applicable
for rio_warning in captured_warnings:
if not riods.subdatasets or not isinstance(
Expand Down
21 changes: 21 additions & 0 deletions test/integration/test_integration__io.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,8 @@ def create_tmp_geotiff(
) as s:
for attr, val in additional_attrs.items():
setattr(s, attr, val)
for band in range(1, nz + 1):
s.update_tags(band, BAND=band)
s.write(data, **write_kwargs)
dx, dy = s.res[0], -s.res[1]
tt = s.transform
Expand Down Expand Up @@ -480,6 +482,25 @@ def test_utm():
assert "y" not in rioda.coords


def test_band_as_variable():
with create_tmp_geotiff() as (tmp_file, expected):
with rioxarray.open_rasterio(tmp_file, band_as_variable=True) as riods:
for band in (1, 2, 3):
band_name = f"band_{band}"
assert_allclose(riods[band_name], expected.sel(band=band).drop("band"))
assert riods[band_name].attrs["BAND"] == band
assert riods[band_name].attrs["scale_factor"] == 1.0
assert riods[band_name].attrs["add_offset"] == 0.0
assert riods[band_name].attrs["long_name"] == f"d{band}"
assert riods[band_name].attrs["units"] == f"u{band}"
assert riods[band_name].rio.crs == expected.rio.crs
assert_array_equal(
riods[band_name].rio.resolution(), expected.rio.resolution()
)
assert isinstance(riods[band_name].rio._cached_transform(), Affine)
assert riods[band_name].rio.nodata is None


def test_platecarree():
with create_tmp_geotiff(
8,
Expand Down

0 comments on commit 9643b46

Please sign in to comment.