Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve cutline handling #598

Merged
merged 3 commits into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@

- Fix potential issue when getting statistics for non-valid data
- add `rio-tiler.mosaic.methods.PixelSelectionMethod` enums with all default methods
- Add `rio-tiler.utils._validate_shape_input` function to check geojson feature inputs
- Change cutline handling in the `rio-tiler.io.rasterio.Reader.feature` method. Feature
cutlines are now rasterized into numpy arrays and applied as masks instead of using
the cutline vrt_option. These masks are tracked in the `rio-tiler.models.ImageData.cutline_mask`
attribute, which is used in `rio-tiler.mosaic.methods.base.MosaicMethodBase` to stop
mosaic building as soon as all pixels in a feature are populated

**breaking changes**

Expand Down
39 changes: 28 additions & 11 deletions docs/src/advanced/feature.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,56 @@ with Reader("my-tif.tif") as cog:
img: ImageData = cog.feature(geojson_feature, max_size=1024) # we limit the max_size to 1024
```

Under the hood, the `.feature` method uses `GDALWarpVRT`'s `cutline` option and
the `.part()` method. The below process is roughly what `.feature` does for you.
Under the hood, the `.feature` method uses rasterio's [`rasterize`](https://rasterio.readthedocs.io/en/latest/api/rasterio.features.html#rasterio.features.rasterize)
function and the `.part()` method. The below process is roughly what `.feature` does for you.

```python
from rasterio.features import rasterize, bounds as featureBounds

from rio_tiler.io import Reader
from rio_tiler.utils import create_cutline
from rasterio.features import bounds as featureBounds

# Use Reader to open and read the dataset
with Reader("my_tif.tif") as cog:
# Create WKT Cutline
cutline = create_cutline(cog.dataset, feat, geometry_crs="epsg:4326")

# Get BBOX of the polygon
bbox = featureBounds(feat)

# Read part of the data (bbox) and use the cutline to mask the data
data, mask = cog.part(bbox, vrt_options={'cutline': cutline}, max_size=1024)
# Read part of the data overlapping with the geometry bbox
# assuming that the geometry coordinates are in web mercator
img = cog.part(bbox, bounds_crs=f"EPSG:3857", max_size=1024)

# Rasterize geometry using the same geotransform parameters
cutline = rasterize(
[feat],
out_shape=(img.height, img.width),
transform=img.transform,
...
)

# Apply geometry mask to imagery
img.array.mask = numpy.where(~cutline, img.array.mask, True)
```

Another interesting fact about the `cutline` option is that it can be used with other methods:
Another interesting way to cut features is to use `GDALWarpVRT`'s `cutline`
option with the `.part()`, `.preview()`, or `.tile()` methods:

```python
from rio_tiler.utils import create_cutline

bbox = featureBounds(feat)
yellowcap marked this conversation as resolved.
Show resolved Hide resolved

# Use Reader to open and read the dataset
with Reader("my_tif.tif") as cog:
# Create WKT Cutline
cutline = create_cutline(cog.dataset, feat, geometry_crs="epsg:4326")

# Get a part of the geotiff but use the cutline to mask the data
bbox = featureBounds(feat)
img = cog.part(bbox, vrt_options={'cutline': cutline})

# Get a preview of the whole geotiff but use the cutline to mask the data
data, mask = cog.preview(vrt_options={'cutline': cutline})
img = cog.preview(vrt_options={'cutline': cutline})

# Read a mercator tile and use the cutline to mask the data
data, mask = cog.tile(1, 1, 1, vrt_options={'cutline': cutline})
img = cog.tile(1, 1, 1, vrt_options={'cutline': cutline})
```
24 changes: 19 additions & 5 deletions rio_tiler/io/rasterio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from rasterio import transform
from rasterio.crs import CRS
from rasterio.features import bounds as featureBounds
from rasterio.features import geometry_mask
from rasterio.features import geometry_mask, rasterize
from rasterio.io import DatasetReader, DatasetWriter, MemoryFile
from rasterio.rio.overview import get_maximum_overview_level
from rasterio.transform import from_bounds as transform_from_bounds
Expand All @@ -34,7 +34,7 @@
from rio_tiler.io.base import BaseReader
from rio_tiler.models import BandStatistics, ImageData, Info, PointData
from rio_tiler.types import BBox, Indexes, NumType, RIOResampling
from rio_tiler.utils import create_cutline, has_alpha_band, has_mask_band
from rio_tiler.utils import _validate_shape_input, has_alpha_band, has_mask_band


@attr.s
Expand Down Expand Up @@ -521,17 +521,17 @@ def feature(
rio_tiler.models.ImageData: ImageData instance with data, mask and input spatial info.

"""
shape = _validate_shape_input(shape)

if not dst_crs:
dst_crs = shape_crs

# Get BBOX of the polygon
bbox = featureBounds(shape)

cutline = create_cutline(self.dataset, shape, geometry_crs=shape_crs)
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved
vrt_options = kwargs.pop("vrt_options", {})
vrt_options.update({"cutline": cutline})

return self.part(
img = self.part(
bbox,
dst_crs=dst_crs,
bounds_crs=shape_crs,
Expand All @@ -545,6 +545,20 @@ def feature(
**kwargs,
)

cutline_mask = rasterize(
[shape],
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved
out_shape=(img.height, img.width),
transform=img.transform,
all_touched=True, # Necessary for matching masks at different resolutions
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

default_value=0,
fill=1,
dtype="uint8",
).astype("bool")
img.cutline_mask = cutline_mask
img.array.mask = numpy.where(~cutline_mask, img.array.mask, True)

return img

def read(
self,
indexes: Optional[Indexes] = None,
Expand Down
10 changes: 3 additions & 7 deletions rio_tiler/io/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
from morecantile import Tile, TileMatrixSet
from rasterio.crs import CRS
from rasterio.enums import Resampling
from rasterio.features import is_valid_geom
from rasterio.rio.overview import get_maximum_overview_level
from rasterio.transform import from_bounds, rowcol
from rasterio.warp import calculate_default_transform
from rasterio.warp import transform as transform_coords

from rio_tiler.constants import WEB_MERCATOR_TMS, WGS84_CRS
from rio_tiler.errors import PointOutsideBounds, RioTilerError, TileOutsideBounds
from rio_tiler.errors import PointOutsideBounds, TileOutsideBounds
from rio_tiler.io.base import BaseReader
from rio_tiler.models import BandStatistics, ImageData, Info, PointData
from rio_tiler.types import BBox, WarpResampling
from rio_tiler.utils import _validate_shape_input

try:
import xarray
Expand Down Expand Up @@ -373,11 +373,7 @@ def feature(
if not dst_crs:
dst_crs = shape_crs

if "geometry" in shape:
shape = shape["geometry"]

if not is_valid_geom(shape):
raise RioTilerError("Invalid geometry")
shape = _validate_shape_input(shape)

ds = self.input.rio.clip([shape], crs=shape_crs)

Expand Down
12 changes: 11 additions & 1 deletion rio_tiler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ class ImageData:
dataset_statistics: Optional[Sequence[Tuple[float, float]]] = attr.ib(
default=None, kw_only=True
)
cutline_mask: Optional[numpy.ndarray] = attr.ib(default=None)

@band_names.default
def _default_names(self):
Expand Down Expand Up @@ -421,12 +422,20 @@ def create_from_list(cls, data: Sequence["ImageData"]) -> "ImageData":

"""
h, w = zip(*[(img.height, img.width) for img in data])

# Get cutline mask at highest resolution.
max_h, max_w = max(h), max(w)
cutline_mask = next(
img.cutline_mask
for img in data
if img.height == max_h and img.width == max_w
)

if len(set(h)) > 1 or len(set(w)) > 1:
warnings.warn(
"Cannot concatenate images with different size. Will resize using max width/heigh",
UserWarning,
)
max_h, max_w = max(h), max(w)
for img in data:
if img.height == max_h and img.width == max_w:
continue
Expand Down Expand Up @@ -472,6 +481,7 @@ def create_from_list(cls, data: Sequence["ImageData"]) -> "ImageData":
bounds=bounds,
band_names=band_names,
dataset_statistics=dataset_statistics,
cutline_mask=cutline_mask,
)

def as_masked(self) -> numpy.ma.MaskedArray:
Expand Down
12 changes: 10 additions & 2 deletions rio_tiler/mosaic/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class MosaicMethodBase(abc.ABC):

mosaic: Optional[numpy.ma.MaskedArray] = field(default=None, init=False)
exit_when_filled: bool = field(default=False, init=False)
cutline_mask: Optional[numpy.ndarray] = field(default=None, init=False)

@property
def is_done(self) -> bool:
Expand All @@ -25,8 +26,15 @@ def is_done(self) -> bool:
if self.mosaic is None:
return False

if self.exit_when_filled and not numpy.ma.is_masked(self.mosaic):
return True
if self.exit_when_filled:
if (
self.cutline_mask is not None
and numpy.sum(numpy.where(~self.cutline_mask, self.mosaic.mask, False))
== 0
):
return True
elif not numpy.ma.is_masked(self.mosaic):
return True

return False

Expand Down
2 changes: 2 additions & 0 deletions rio_tiler/mosaic/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def mosaic_reader(
bounds = img.bounds
band_names = img.band_names

pixel_selection.cutline_mask = img.cutline_mask

assets_used.append(asset)
pixel_selection.feed(img.array)

Expand Down
19 changes: 13 additions & 6 deletions rio_tiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,12 +536,7 @@ def create_cutline(
str: WKT geometry in form of `POLYGON ((x y, x y, ...))`

"""
if "geometry" in geometry:
geometry = geometry["geometry"]

if not is_valid_geom(geometry):
raise RioTilerError("Invalid geometry")

geometry = _validate_shape_input(geometry)
geom_type = geometry["type"]

if geometry_crs:
Expand Down Expand Up @@ -635,3 +630,15 @@ def normalize_bounds(bounds: BBox) -> BBox:
max(bounds[0], bounds[2]),
max(bounds[1], bounds[3]),
)


def _validate_shape_input(shape: Dict) -> Dict:
    """Return the bare geometry of ``shape`` after checking that it is valid.

    Args:
        shape (dict): GeoJSON-like mapping. When it carries a ``geometry`` key
            (e.g. a Feature), only that member is kept; otherwise the mapping
            itself is treated as the geometry.

    Returns:
        dict: the geometry mapping that passed validation.

    Raises:
        RioTilerError: if ``is_valid_geom`` rejects the geometry.

    """
    # Unwrap Feature-like inputs so callers always get a plain geometry back.
    geometry = shape["geometry"] if "geometry" in shape else shape

    if is_valid_geom(geometry):
        return geometry

    raise RioTilerError("Invalid geometry")
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved
Binary file added tests/fixtures/lowres.tif
Binary file not shown.
6 changes: 5 additions & 1 deletion tests/fixtures/stac.json
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,16 @@
"href": "http://somewhere-over-the-rainbow.io/green.tif",
"title": "green",
"file:header_size": 30000
},
},
"blue": {
"href": "http://somewhere-over-the-rainbow.io/blue.tif",
"title": "blue",
"file:header_size": 20000
},
"lowres": {
"href": "http://somewhere-over-the-rainbow.io/lowres.tif",
"title": "lowres"
},
"thumbnail": {
"href": "http://cool-sat.com/catalog/a-fake-item/thumbnail.png",
"title": "Thumbnail",
Expand Down
27 changes: 18 additions & 9 deletions tests/test_io_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_fetch_stac(httpx, s3_get):
assert stac.maxzoom == 24
assert stac.bounds
assert stac.input == STAC_PATH
assert stac.assets == ["red", "green", "blue"]
assert stac.assets == ["red", "green", "blue", "lowres"]
httpx.assert_not_called()
s3_get.assert_not_called()

Expand All @@ -55,13 +55,13 @@ def test_fetch_stac(httpx, s3_get):
assert stac.minzoom == 0
assert stac.maxzoom == 24
assert not stac.input
assert stac.assets == ["red", "green", "blue"]
assert stac.assets == ["red", "green", "blue", "lowres"]
httpx.assert_not_called()
s3_get.assert_not_called()

# Exclude red
with STACReader(STAC_PATH, exclude_assets={"red"}) as stac:
assert stac.assets == ["green", "blue"]
assert stac.assets == ["green", "blue", "lowres"]
httpx.assert_not_called()
s3_get.assert_not_called()

Expand Down Expand Up @@ -109,7 +109,7 @@ def raise_for_status(self):
httpx.get.return_value = MockResponse(f.read())

with STACReader("http://somewhereovertherainbow.io/mystac.json") as stac:
assert stac.assets == ["red", "green", "blue"]
assert stac.assets == ["red", "green", "blue", "lowres"]
httpx.get.assert_called_once()
s3_get.assert_not_called()
httpx.mock_reset()
Expand All @@ -119,7 +119,7 @@ def raise_for_status(self):
s3_get.return_value = f.read()

with STACReader("s3://somewhereovertherainbow.io/mystac.json") as stac:
assert stac.assets == ["red", "green", "blue"]
assert stac.assets == ["red", "green", "blue", "lowres"]
httpx.assert_not_called()
s3_get.assert_called_once()
assert s3_get.call_args[0] == ("somewhereovertherainbow.io", "mystac.json")
Expand Down Expand Up @@ -448,7 +448,7 @@ def test_merged_statistics_valid(rio):
with STACReader(STAC_PATH) as stac:
with pytest.warns(UserWarning):
stats = stac.merged_statistics()
assert len(stats) == 3
assert len(stats) == 4
assert isinstance(stats["red_b1"], BandStatistics)
assert stats["red_b1"]
assert stats["green_b1"]
Expand Down Expand Up @@ -603,6 +603,15 @@ def test_feature_valid(rio):
assert img.mask.shape == (118, 96)
assert img.band_names == ["green_b1*2", "green_b1", "red_b1*2"]

with pytest.warns(
UserWarning,
match="Cannot concatenate images with different size. Will resize using max width/heigh",
):
img = stac.feature(feat, assets=("blue", "lowres"))
assert img.data.shape == (2, 118, 96)
assert img.mask.shape == (118, 96)
assert img.band_names == ["blue_b1", "lowres_b1"]


def test_relative_assets():
"""Should return absolute href for assets"""
Expand Down Expand Up @@ -640,7 +649,7 @@ def raise_for_status(self):
"headers": {"Authorization": "Bearer token"},
},
) as stac:
assert stac.assets == ["red", "green", "blue"]
assert stac.assets == ["red", "green", "blue", "lowres"]
httpx.get.assert_called_once()
assert httpx.get.call_args[1]["auth"] == ("user", "pass")
assert httpx.get.call_args[1]["headers"] == {"Authorization": "Bearer token"}
Expand All @@ -653,7 +662,7 @@ def raise_for_status(self):
"headers": {"Authorization": "Bearer token"},
},
) as stac:
assert stac.assets == ["red", "green", "blue"]
assert stac.assets == ["red", "green", "blue", "lowres"]

# Check if it was cached
assert httpx.get.call_count == 1
Expand All @@ -667,7 +676,7 @@ def raise_for_status(self):
"s3://somewhereovertherainbow.io/mystac.json",
fetch_options={"request_pays": True},
) as stac:
assert stac.assets == ["red", "green", "blue"]
assert stac.assets == ["red", "green", "blue", "lowres"]
httpx.assert_not_called()
s3_get.assert_called_once()
assert s3_get.call_args[1]["request_pays"]
Expand Down