Skip to content

Commit

Permalink
Improve cutline handling
Browse files Browse the repository at this point in the history
Closes #588
  • Loading branch information
yellowcap committed Apr 26, 2023
1 parent 94f8b53 commit 8dfa736
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 31 deletions.
6 changes: 6 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@

- Fix potential issue when getting statistics for non-valid data
- add `rio-tiler.mosaic.methods.PixelSelectionMethod` enums with all defaults methods
- Add `rio-tiler.utils._validate_shape_input` function to check geojson feature inputs
- Change cutline handling in the `rio-tiler.io.rasterio.Reader.feature` method. Feature
cutlines are now rasterized into numpy arrays and applied as masks instead of using
the cutline vrt_option. These masks are tracked in the `rio-tiler.models.ImageData.cutline_mask`
attribute, which are used in `rio-tiler.mosaic.methods.base.MosaicMethodBase` to stop
mosaic building as soon as all pixels in a feature are populated

**breaking changes**

Expand Down
39 changes: 28 additions & 11 deletions docs/src/advanced/feature.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,56 @@ with Reader("my-tif.tif") as cog:
img: ImageData = cog.feature(geojson_feature, max_size=1024) # we limit the max_size to 1024
```

Under the hood, the `.feature` method uses `GDALWarpVRT`'s `cutline` option and
the `.part()` method. The below process is roughly what `.feature` does for you.
Under the hood, the `.feature` method uses rasterio's [`rasterize`](https://rasterio.readthedocs.io/en/latest/api/rasterio.features.html#rasterio.features.rasterize)
function and the `.part()` method. The below process is roughly what `.feature` does for you.

```python
from rasterio.features import rasterize, bounds as featureBounds

from rio_tiler.io import Reader
from rio_tiler.utils import create_cutline
from rasterio.features import bounds as featureBounds

# Use Reader to open and read the dataset
with Reader("my_tif.tif") as cog:
# Create WTT Cutline
cutline = create_cutline(cog.dataset, feat, geometry_crs="epsg:4326")

# Get BBOX of the polygon
bbox = featureBounds(feat)

# Read part of the data (bbox) and use the cutline to mask the data
data, mask = cog.part(bbox, vrt_options={'cutline': cutline}, max_size=1024)
# Read part of the data overlapping with the geometry bbox
# assuming that the geometry coordinates are in web mercator
img = cog.part(bbox, bounds_crs=f"EPSG:3857", max_size=1024)

# Rasterize geometry using the same geotransform parameters
cutline = rasterize(
[feat],
out_shape=(img.height, img.width),
transform=img.transform,
...
)

# Apply geometry mask to imagery
img.array.mask = numpy.where(~cutline, img.array.mask, True)
```

Another interesting fact about the `cutline` option is that it can be used with other methods:
Another interesting way to cut features is to use the GDALWarpVRT's `cutline`
option with the .part(), .preview(), or .tile() methods:

```python
from rio_tiler.utils import create_cutline

bbox = featureBounds(feat)

# Use Reader to open and read the dataset
with Reader("my_tif.tif") as cog:
# Create WTT Cutline
cutline = create_cutline(cog.dataset, feat, geometry_crs="epsg:4326")

# Get a part of the geotiff but use the cutline to mask the data
bbox = featureBounds(feat)
img = cog.part(bbox, vrt_options={'cutline': cutline})

# Get a preview of the whole geotiff but use the cutline to mask the data
data, mask = cog.preview(vrt_options={'cutline': cutline})
img = cog.preview(vrt_options={'cutline': cutline})

# Read a mercator tile and use the cutline to mask the data
data, mask = cog.tile(1, 1, 1, vrt_options={'cutline': cutline})
img = cog.tile(1, 1, 1, vrt_options={'cutline': cutline})
```
27 changes: 22 additions & 5 deletions rio_tiler/io/rasterio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from rasterio import transform
from rasterio.crs import CRS
from rasterio.features import bounds as featureBounds
from rasterio.features import geometry_mask
from rasterio.features import geometry_mask, rasterize
from rasterio.io import DatasetReader, DatasetWriter, MemoryFile
from rasterio.rio.overview import get_maximum_overview_level
from rasterio.transform import from_bounds as transform_from_bounds
Expand All @@ -28,13 +28,14 @@
ExpressionMixingWarning,
NoOverviewWarning,
PointOutsideBounds,
RioTilerError,
TileOutsideBounds,
)
from rio_tiler.expression import parse_expression
from rio_tiler.io.base import BaseReader
from rio_tiler.models import BandStatistics, ImageData, Info, PointData
from rio_tiler.types import BBox, Indexes, NumType, RIOResampling
from rio_tiler.utils import create_cutline, has_alpha_band, has_mask_band
from rio_tiler.utils import _validate_shape_input, has_alpha_band, has_mask_band


@attr.s
Expand Down Expand Up @@ -501,6 +502,7 @@ def feature(
height: Optional[int] = None,
width: Optional[int] = None,
buffer: Optional[NumType] = None,
all_touched: Optional[bool] = False,
**kwargs: Any,
) -> ImageData:
"""Read part of a Dataset defined by a geojson feature.
Expand All @@ -515,23 +517,24 @@ def feature(
height (int, optional): Output height of the array.
width (int, optional): Output width of the array.
buffer (int or float, optional): Buffer on each side of the given aoi. It must be a multiple of `0.5`. Output **image size** will be expanded to `output imagesize + 2 * buffer` (e.g 0.5 = 257x257, 1.0 = 258x258).
all_touched (bool, optional): Shape rasterization parameter. If `True`, all pixels touched by the shape will be included. If `False`, only those whose center point is within the shape will be included.
kwargs (optional): Options to forward to the `Reader.part` method.
Returns:
rio_tiler.models.ImageData: ImageData instance with data, mask and input spatial info.
"""
shape = _validate_shape_input(shape)

if not dst_crs:
dst_crs = shape_crs

# Get BBOX of the polygon
bbox = featureBounds(shape)

cutline = create_cutline(self.dataset, shape, geometry_crs=shape_crs)
vrt_options = kwargs.pop("vrt_options", {})
vrt_options.update({"cutline": cutline})

return self.part(
img = self.part(
bbox,
dst_crs=dst_crs,
bounds_crs=shape_crs,
Expand All @@ -545,6 +548,20 @@ def feature(
**kwargs,
)

cutline_mask = rasterize(
[shape],
out_shape=(img.height, img.width),
transform=img.transform,
all_touched=all_touched,
default_value=0,
fill=1,
dtype="uint8",
).astype("bool")
img.cutline_mask = cutline_mask
img.array.mask = numpy.where(~cutline_mask, img.array.mask, True)

return img

def read(
self,
indexes: Optional[Indexes] = None,
Expand Down
10 changes: 3 additions & 7 deletions rio_tiler/io/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
from morecantile import Tile, TileMatrixSet
from rasterio.crs import CRS
from rasterio.enums import Resampling
from rasterio.features import is_valid_geom
from rasterio.rio.overview import get_maximum_overview_level
from rasterio.transform import from_bounds, rowcol
from rasterio.warp import calculate_default_transform
from rasterio.warp import transform as transform_coords

from rio_tiler.constants import WEB_MERCATOR_TMS, WGS84_CRS
from rio_tiler.errors import PointOutsideBounds, RioTilerError, TileOutsideBounds
from rio_tiler.errors import PointOutsideBounds, TileOutsideBounds
from rio_tiler.io.base import BaseReader
from rio_tiler.models import BandStatistics, ImageData, Info, PointData
from rio_tiler.types import BBox, WarpResampling
from rio_tiler.utils import _validate_shape_input

try:
import xarray
Expand Down Expand Up @@ -373,11 +373,7 @@ def feature(
if not dst_crs:
dst_crs = shape_crs

if "geometry" in shape:
shape = shape["geometry"]

if not is_valid_geom(shape):
raise RioTilerError("Invalid geometry")
shape = _validate_shape_input(shape)

ds = self.input.rio.clip([shape], crs=shape_crs)

Expand Down
5 changes: 5 additions & 0 deletions rio_tiler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ class ImageData:
dataset_statistics: Optional[Sequence[Tuple[float, float]]] = attr.ib(
default=None, kw_only=True
)
cutline_mask: Optional[numpy.ndarray] = attr.ib(default=None)

@band_names.default
def _default_names(self):
Expand Down Expand Up @@ -429,12 +430,15 @@ def create_from_list(cls, data: Sequence["ImageData"]) -> "ImageData":
max_h, max_w = max(h), max(w)
for img in data:
if img.height == max_h and img.width == max_w:
cutline_mask = img.cutline_mask
continue
arr = numpy.ma.MaskedArray(
resize_array(img.array.data, max_h, max_w),
mask=resize_array(img.array.mask * 1, max_h, max_w).astype("bool"),
)
img.array = arr
else:
cutline_mask = data[0].cutline_mask

arr = numpy.ma.concatenate([img.array for img in data])

Expand Down Expand Up @@ -472,6 +476,7 @@ def create_from_list(cls, data: Sequence["ImageData"]) -> "ImageData":
bounds=bounds,
band_names=band_names,
dataset_statistics=dataset_statistics,
cutline_mask=cutline_mask,
)

def as_masked(self) -> numpy.ma.MaskedArray:
Expand Down
12 changes: 10 additions & 2 deletions rio_tiler/mosaic/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class MosaicMethodBase(abc.ABC):

mosaic: Optional[numpy.ma.MaskedArray] = field(default=None, init=False)
exit_when_filled: bool = field(default=False, init=False)
cutline_mask: Optional[numpy.ndarray] = field(default=None, init=False)

@property
def is_done(self) -> bool:
Expand All @@ -25,8 +26,15 @@ def is_done(self) -> bool:
if self.mosaic is None:
return False

if self.exit_when_filled and not numpy.ma.is_masked(self.mosaic):
return True
if self.exit_when_filled:
if (
self.cutline_mask is not None
and numpy.sum(numpy.where(~self.cutline_mask, self.mosaic.mask, False))
== 0
):
return True
elif not numpy.ma.is_masked(self.mosaic):
return True

return False

Expand Down
2 changes: 2 additions & 0 deletions rio_tiler/mosaic/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def mosaic_reader(
bounds = img.bounds
band_names = img.band_names

pixel_selection.cutline_mask = img.cutline_mask

assets_used.append(asset)
pixel_selection.feed(img.array)

Expand Down
19 changes: 13 additions & 6 deletions rio_tiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,12 +537,7 @@ def create_cutline(
str: WKT geometry in form of `POLYGON ((x y, x y, ...)))
"""
if "geometry" in geometry:
geometry = geometry["geometry"]

if not is_valid_geom(geometry):
raise RioTilerError("Invalid geometry")

geometry = _validate_shape_input(geometry)
geom_type = geometry["type"]

if geometry_crs:
Expand Down Expand Up @@ -636,3 +631,15 @@ def normalize_bounds(bounds: BBox) -> BBox:
max(bounds[0], bounds[2]),
max(bounds[1], bounds[3]),
)


def _validate_shape_input(shape: Dict) -> Dict:
"""Ensure input shape is valid and reduce features to geometry"""

if "geometry" in shape:
shape = shape["geometry"]

if not is_valid_geom(shape):
raise RioTilerError("Invalid geometry")

return shape

0 comments on commit 8dfa736

Please sign in to comment.