Skip to content

Commit

Permalink
improve cutline handling
Browse files Browse the repository at this point in the history
  • Loading branch information
yellowcap committed Apr 24, 2023
1 parent 94f8b53 commit e93b796
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 31 deletions.
37 changes: 26 additions & 11 deletions docs/src/advanced/feature.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,38 @@ with Reader("my-tif.tif") as cog:
img: ImageData = cog.feature(geojson_feature, max_size=1024) # we limit the max_size to 1024
```

Under the hood, the `.feature` method uses `GDALWarpVRT`'s `cutline` option and
the `.part()` method. The below process is roughly what `.feature` does for you.
Under the hood, the `.feature` method uses rasterio's (`rasterize`)[https://rasterio.readthedocs.io/en/latest/api/rasterio.features.html#rasterio.features.rasterize]
function and the `.part()` method. The below process is roughly what `.feature` does for you.

```python
from rasterio.features import rasterize, bounds as featureBounds

from rio_tiler.io import Reader
from rio_tiler.utils import create_cutline
from rasterio.features import bounds as featureBounds

# Use Reader to open and read the dataset
with Reader("my_tif.tif") as cog:
# Create WTT Cutline
cutline = create_cutline(cog.dataset, feat, geometry_crs="epsg:4326")

# Get BBOX of the polygon
bbox = featureBounds(feat)

# Read part of the data (bbox) and use the cutline to mask the data
data, mask = cog.part(bbox, vrt_options={'cutline': cutline}, max_size=1024)
# Read part of the data overlapping with the geometry bbox
# assuming that the geometry coordinates are in web mercator
img = cog.part(bbox, bounds_crs=f"EPSG:3857", max_size=1024)

# Rasterize geometry using the same geotransform parameters
cutline = rasterize(
[feat],
out_shape=(img.height, img.width),
transform=img.transform,
...
)

# Apply geometry mask to imagery
img.array.mask = numpy.where(~cutline, img.array.mask, True)
```

Another interesting fact about the `cutline` option is that it can be used with other methods:
Another interesting way to cut features is to use the GDALWarpVRT's `cutline`
option with the .part(), .preview(), or .tile() methods:

```python
bbox = featureBounds(feat)
Expand All @@ -41,9 +52,13 @@ with Reader("my_tif.tif") as cog:
# Create WTT Cutline
cutline = create_cutline(cog.dataset, feat, geometry_crs="epsg:4326")

# Get a part of the geotiff but use the cutline to mask the data
bbox = featureBounds(feat)
img = cog.part(bbox, vrt_options={'cutline': cutline})

# Get a preview of the whole geotiff but use the cutline to mask the data
data, mask = cog.preview(vrt_options={'cutline': cutline})
img = cog.preview(vrt_options={'cutline': cutline})

# Read a mercator tile and use the cutline to mask the data
data, mask = cog.tile(1, 1, 1, vrt_options={'cutline': cutline})
img = cog.tile(1, 1, 1, vrt_options={'cutline': cutline})
```
27 changes: 22 additions & 5 deletions rio_tiler/io/rasterio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from rasterio import transform
from rasterio.crs import CRS
from rasterio.features import bounds as featureBounds
from rasterio.features import geometry_mask
from rasterio.features import geometry_mask, rasterize
from rasterio.io import DatasetReader, DatasetWriter, MemoryFile
from rasterio.rio.overview import get_maximum_overview_level
from rasterio.transform import from_bounds as transform_from_bounds
Expand All @@ -28,13 +28,14 @@
ExpressionMixingWarning,
NoOverviewWarning,
PointOutsideBounds,
RioTilerError,
TileOutsideBounds,
)
from rio_tiler.expression import parse_expression
from rio_tiler.io.base import BaseReader
from rio_tiler.models import BandStatistics, ImageData, Info, PointData
from rio_tiler.types import BBox, Indexes, NumType, RIOResampling
from rio_tiler.utils import create_cutline, has_alpha_band, has_mask_band
from rio_tiler.utils import _validate_shape_input, has_alpha_band, has_mask_band


@attr.s
Expand Down Expand Up @@ -501,6 +502,7 @@ def feature(
height: Optional[int] = None,
width: Optional[int] = None,
buffer: Optional[NumType] = None,
all_touched: Optional[bool] = False,
**kwargs: Any,
) -> ImageData:
"""Read part of a Dataset defined by a geojson feature.
Expand All @@ -515,23 +517,24 @@ def feature(
height (int, optional): Output height of the array.
width (int, optional): Output width of the array.
buffer (int or float, optional): Buffer on each side of the given aoi. It must be a multiple of `0.5`. Output **image size** will be expanded to `output imagesize + 2 * buffer` (e.g 0.5 = 257x257, 1.0 = 258x258).
all_touched (bool, optional): Shape rasterization parameter. If True, all pixels touched the input shape will be burned in. If false, only pixels whose center is within the input shape or that are selected by Bresenham's line algorithm will be burned in.
kwargs (optional): Options to forward to the `Reader.part` method.
Returns:
rio_tiler.models.ImageData: ImageData instance with data, mask and input spatial info.
"""
shape = _validate_shape_input(shape)

if not dst_crs:
dst_crs = shape_crs

# Get BBOX of the polygon
bbox = featureBounds(shape)

cutline = create_cutline(self.dataset, shape, geometry_crs=shape_crs)
vrt_options = kwargs.pop("vrt_options", {})
vrt_options.update({"cutline": cutline})

return self.part(
img = self.part(
bbox,
dst_crs=dst_crs,
bounds_crs=shape_crs,
Expand All @@ -545,6 +548,20 @@ def feature(
**kwargs,
)

cutline_mask = rasterize(
[shape],
out_shape=(img.height, img.width),
transform=img.transform,
all_touched=all_touched,
default_value=0,
fill=1,
dtype="uint8",
).astype("bool")
img.cutline_mask = cutline_mask
img.array.mask = numpy.where(~cutline_mask, img.array.mask, True)

return img

def read(
self,
indexes: Optional[Indexes] = None,
Expand Down
10 changes: 3 additions & 7 deletions rio_tiler/io/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
from morecantile import Tile, TileMatrixSet
from rasterio.crs import CRS
from rasterio.enums import Resampling
from rasterio.features import is_valid_geom
from rasterio.rio.overview import get_maximum_overview_level
from rasterio.transform import from_bounds, rowcol
from rasterio.warp import calculate_default_transform
from rasterio.warp import transform as transform_coords

from rio_tiler.constants import WEB_MERCATOR_TMS, WGS84_CRS
from rio_tiler.errors import PointOutsideBounds, RioTilerError, TileOutsideBounds
from rio_tiler.errors import PointOutsideBounds, TileOutsideBounds
from rio_tiler.io.base import BaseReader
from rio_tiler.models import BandStatistics, ImageData, Info, PointData
from rio_tiler.types import BBox, WarpResampling
from rio_tiler.utils import _validate_shape_input

try:
import xarray
Expand Down Expand Up @@ -373,11 +373,7 @@ def feature(
if not dst_crs:
dst_crs = shape_crs

if "geometry" in shape:
shape = shape["geometry"]

if not is_valid_geom(shape):
raise RioTilerError("Invalid geometry")
shape = _validate_shape_input(shape)

ds = self.input.rio.clip([shape], crs=shape_crs)

Expand Down
5 changes: 5 additions & 0 deletions rio_tiler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ class ImageData:
dataset_statistics: Optional[Sequence[Tuple[float, float]]] = attr.ib(
default=None, kw_only=True
)
cutline_mask: Optional[numpy.ndarray] = attr.ib(default=None)

@band_names.default
def _default_names(self):
Expand Down Expand Up @@ -429,12 +430,15 @@ def create_from_list(cls, data: Sequence["ImageData"]) -> "ImageData":
max_h, max_w = max(h), max(w)
for img in data:
if img.height == max_h and img.width == max_w:
cutline_mask = img.cutline_mask
continue
arr = numpy.ma.MaskedArray(
resize_array(img.array.data, max_h, max_w),
mask=resize_array(img.array.mask * 1, max_h, max_w).astype("bool"),
)
img.array = arr
else:
cutline_mask = data[0].cutline_mask

arr = numpy.ma.concatenate([img.array for img in data])

Expand Down Expand Up @@ -472,6 +476,7 @@ def create_from_list(cls, data: Sequence["ImageData"]) -> "ImageData":
bounds=bounds,
band_names=band_names,
dataset_statistics=dataset_statistics,
cutline_mask=cutline_mask,
)

def as_masked(self) -> numpy.ma.MaskedArray:
Expand Down
12 changes: 10 additions & 2 deletions rio_tiler/mosaic/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class MosaicMethodBase(abc.ABC):

mosaic: Optional[numpy.ma.MaskedArray] = field(default=None, init=False)
exit_when_filled: bool = field(default=False, init=False)
cutline_mask: Optional[numpy.ndarray] = field(default=None, init=False)

@property
def is_done(self) -> bool:
Expand All @@ -25,8 +26,15 @@ def is_done(self) -> bool:
if self.mosaic is None:
return False

if self.exit_when_filled and not numpy.ma.is_masked(self.mosaic):
return True
if self.exit_when_filled:
if (
self.cutline_mask is not None
and numpy.sum(numpy.where(~self.cutline_mask, self.mosaic.mask, False))
== 0
):
return True
elif not numpy.ma.is_masked(self.mosaic):
return True

return False

Expand Down
2 changes: 2 additions & 0 deletions rio_tiler/mosaic/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def mosaic_reader(
bounds = img.bounds
band_names = img.band_names

pixel_selection.cutline_mask = img.cutline_mask

assets_used.append(asset)
pixel_selection.feed(img.array)

Expand Down
19 changes: 13 additions & 6 deletions rio_tiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,12 +537,7 @@ def create_cutline(
str: WKT geometry in form of `POLYGON ((x y, x y, ...)))
"""
if "geometry" in geometry:
geometry = geometry["geometry"]

if not is_valid_geom(geometry):
raise RioTilerError("Invalid geometry")

geometry = _validate_shape_input(geometry)
geom_type = geometry["type"]

if geometry_crs:
Expand Down Expand Up @@ -636,3 +631,15 @@ def normalize_bounds(bounds: BBox) -> BBox:
max(bounds[0], bounds[2]),
max(bounds[1], bounds[3]),
)


def _validate_shape_input(shape: Dict) -> Dict:
"""Ensure input shape is valid and reduce features to geometry"""

if "geometry" in shape:
shape = shape["geometry"]

if not is_valid_geom(shape):
raise RioTilerError("Invalid geometry")

return shape

0 comments on commit e93b796

Please sign in to comment.