rs-patch-sampler

Nodata-aware, geographically-stratified patch sampler for satellite raster imagery.


The problem

torchgeo.samplers.RandomGeoSampler and GridGeoSampler generate patch coordinates without any awareness of nodata pixels, cloud masks, or spatial distribution. Training on mostly-nodata patches wastes GPU compute and degrades model quality. Selecting correlated adjacent patches inflates apparent accuracy on spatially autocorrelated imagery.

rs-patch-sampler fills the gap with three composable mechanisms:

  1. Nodata filtering — each candidate patch is evaluated via a cheap windowed read; patches exceeding max_nodata_fraction are rejected before any model sees them.
  2. Geographic spread — when target_n is set, patches are selected to maximise spatial spread (farthest-point strategy), reducing spatial autocorrelation in training data.
  3. Class stratification — when a label raster is provided, class counts per patch guide up-weighting of minority classes during final selection.

Installation

pip install rs-patch-sampler

Requires Python ≥ 3.10. Dependencies: rasterio, numpy, shapely, pyproj, click, tqdm.


Quickstart

Python API

from rs_patch_sampler import PatchSampler, SamplerConfig

cfg = SamplerConfig(
    patch_size=256,          # pixels
    max_nodata_fraction=0.05,  # reject if >5% nodata
    target_n=500,            # keep 500 spatially-spread patches
    seed=42,
)
sampler = PatchSampler("sentinel2_scene.tif", cfg)
patches = sampler.sample(show_progress=True)

print(f"Accepted {len(patches)} patches")
sampler.to_geojson(patches, "patches.geojson")

Each PatchStats object exposes:

| Field | Type | Description |
| --- | --- | --- |
| col_off | int | Column offset of patch origin |
| row_off | int | Row offset of patch origin |
| width | int | Actual patch width (clamped at edges) |
| height | int | Actual patch height (clamped at edges) |
| nodata_fraction | float | Fraction of nodata pixels in patch |
| valid_pixels | int | Number of non-nodata pixels |
| class_counts | dict[int, int] | Class pixel counts (when a label raster is given) |
| window | rasterio.windows.Window | Property returning a window ready for rasterio windowed reads |

CLI

# Basic: 256×256 patches, max 5% nodata
rs-patch-sampler scene.tif --patch-size 256 --max-nodata 0.05

# Limit to 500 spatially-spread patches, write GeoJSON index
rs-patch-sampler scene.tif --n 500 --out-geojson patches.geojson

# With label raster for class statistics
rs-patch-sampler scene.tif --label labels.tif --n 200 --out-json patches.json --verbose

Full CLI help: rs-patch-sampler --help


How it works

Phase 1: Nodata filtering

Candidate patches are generated on a regular grid (stride defaults to patch_size). For each candidate, the sampler performs a windowed read of the first raster band — O(patch_size²) pixels, not the full scene. A pixel is marked nodata if:

  • The raster declares a nodata value and the pixel matches it, OR
  • Any band among the first four has the nodata value (union mask for multi-band rasters), OR
  • No nodata is declared and the pixel equals 0 (common sentinel for uint data).

Patches where nodata_count / total_pixels > max_nodata_fraction are rejected.
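
The rejection rule above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the helper names nodata_fraction and is_accepted are hypothetical, and a patch is represented as nested lists rather than the array a rasterio read would return.

```python
from typing import Optional, Sequence

# One band of a patch, stored as rows of pixel values (illustrative stand-in
# for the array returned by a windowed read).
Band = Sequence[Sequence[int]]


def nodata_fraction(bands: Sequence[Band], nodata: Optional[int]) -> float:
    """Fraction of pixels flagged nodata in any of the first four bands."""
    # Assumption: with no declared nodata value, 0 is treated as the sentinel.
    sentinel = 0 if nodata is None else nodata
    checked = bands[:4]
    rows, cols = len(checked[0]), len(checked[0][0])
    flagged = sum(
        1
        for r in range(rows)
        for c in range(cols)
        # Union mask: a pixel counts as nodata if any checked band matches.
        if any(band[r][c] == sentinel for band in checked)
    )
    return flagged / (rows * cols)


def is_accepted(
    bands: Sequence[Band],
    nodata: Optional[int],
    max_nodata_fraction: float = 0.05,
) -> bool:
    """Keep the patch unless its nodata fraction exceeds the threshold."""
    return nodata_fraction(bands, nodata) <= max_nodata_fraction


band_1 = [[0, 5], [7, 9]]
band_2 = [[3, 0], [4, 6]]
print(nodata_fraction([band_1, band_2], nodata=None))
```

In this toy patch, two of the four pixels contain a 0 in at least one band, so half the patch is flagged and it would be rejected at the default threshold.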

Phase 2: Geographic spread (farthest-point selection)

When target_n is set and no more than target_n candidates survive nodata filtering, all valid patches are returned. When more candidates than target_n survive, the greedy farthest-point algorithm selects target_n patches:

  1. Seed with a random patch.
  2. Iteratively add the candidate that maximises minimum distance to already-selected patches.

This is O(target_n × |candidates|) — acceptable for typical grid sizes (< 100k patches). The spatial spread is measured post-hoc via a lightweight Moran's I approximation on a rook-contiguous grid.
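
The two steps above correspond to the standard greedy farthest-point heuristic. The following is a minimal sketch under assumptions: the function name farthest_point_selection is hypothetical, points are taken to be patch centres in any planar coordinate system, and distance is Euclidean. Keeping a running minimum distance per candidate is what gives the O(target_n × |candidates|) cost.

```python
import math
import random
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def farthest_point_selection(
    points: Sequence[Point], target_n: int, seed: int = 42
) -> List[int]:
    """Return indices of up to target_n points, chosen greedily for spread."""
    if target_n <= 0:
        return []
    if target_n >= len(points):
        return list(range(len(points)))

    rng = random.Random(seed)
    first = rng.randrange(len(points))  # step 1: random seed patch
    selected = [first]

    # min_dist[i] = distance from point i to its nearest selected point.
    min_dist = [math.dist(p, points[first]) for p in points]
    min_dist[first] = -1.0  # mark as taken so it is never picked again

    while len(selected) < target_n:
        # Step 2: pick the candidate farthest from everything selected so far.
        nxt = max(range(len(points)), key=min_dist.__getitem__)
        selected.append(nxt)
        min_dist[nxt] = -1.0
        for i, p in enumerate(points):
            d = math.dist(p, points[nxt])
            if d < min_dist[i]:
                min_dist[i] = d
    return selected


# Patch centres along a line: the second pick is always one of the two ends.
centres = [(float(x), 0.0) for x in range(11)]
print(farthest_point_selection(centres, target_n=3))
```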

Phase 3: Class stratification (optional)

When label_path is given, class pixel counts are computed for each accepted patch via a second windowed read on the label raster. If class_weights are provided, patches are down/up-weighted by their dominant class during a final probabilistic selection step.
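
One way to picture the final probabilistic step is weighted sampling without replacement, where each patch carries the weight of its dominant class. This is only a sketch of that idea: the helper names, the default weight of 1.0 for classes missing from class_weights, and the sequential draw are all assumptions, not the library's documented behaviour.

```python
import random
from typing import Dict, List, Sequence


def dominant_class(class_counts: Dict[int, int]) -> int:
    """Class id with the most pixels in a patch (assumes a non-empty dict)."""
    return max(class_counts, key=class_counts.get)


def weighted_patch_selection(
    patch_class_counts: Sequence[Dict[int, int]],
    class_weights: Dict[int, float],
    n: int,
    seed: int = 42,
) -> List[int]:
    """Draw up to n patch indices without replacement, weighted by dominant class."""
    rng = random.Random(seed)
    # Assumption: classes absent from class_weights get a neutral weight of 1.0.
    weights = [
        class_weights.get(dominant_class(counts), 1.0)
        for counts in patch_class_counts
    ]
    remaining = list(range(len(patch_class_counts)))
    chosen: List[int] = []
    while remaining and len(chosen) < n:
        current = [weights[i] for i in remaining]
        if sum(current) <= 0:
            break  # only zero-weight patches are left
        pick = rng.choices(range(len(remaining)), weights=current, k=1)[0]
        chosen.append(remaining.pop(pick))
    return chosen


counts = [{1: 90, 2: 10}, {1: 20, 2: 80}, {1: 100}, {2: 50, 3: 49}]
# Up-weight class 2 five-fold relative to class 1.
print(weighted_patch_selection(counts, {1: 1.0, 2: 5.0}, n=2))
```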


Measuring spread quality

score = sampler.spread_score()  # float in [0, 1], higher = more dispersed
print(f"Geographic spread: {score:.3f}")

The spread score wraps Moran's I: score = (1 - I) / 2. A value near 1.0 means patches are maximally dispersed; near 0.5 means random; near 0.0 means clustered.
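
To make the mapping concrete, here is a small worked sketch of Moran's I with rook (edge-sharing) neighbours, followed by the score = (1 - I) / 2 transform. It assumes the input is a grid of 0/1 flags marking which candidate cells were selected; the function names are hypothetical and the library's own approximation may differ.

```python
from typing import Sequence


def morans_i_rook(grid: Sequence[Sequence[float]]) -> float:
    """Moran's I on a regular grid with binary rook-contiguity weights."""
    rows, cols = len(grid), len(grid[0])
    n = rows * cols
    mean = sum(sum(row) for row in grid) / n
    dev = [[v - mean for v in row] for row in grid]
    denom = sum(d * d for row in dev for d in row)
    if denom == 0:
        return 0.0  # constant grid: no spatial pattern to measure

    cross = 0.0       # sum of w_ij * dev_i * dev_j over neighbouring cells
    weight_total = 0  # sum of w_ij (each neighbour pair counted in both directions)
    for r in range(rows):
        for c in range(cols):
            for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                rr, cc = r + dr, c + dc
                if 0 <= rr < rows and 0 <= cc < cols:
                    cross += dev[r][c] * dev[rr][cc]
                    weight_total += 1
    return (n * cross) / (weight_total * denom)


def spread_score(grid: Sequence[Sequence[float]]) -> float:
    """Map Moran's I from [-1, 1] onto [0, 1], higher meaning more dispersed."""
    return (1.0 - morans_i_rook(grid)) / 2.0


checkerboard = [[(r + c) % 2 for c in range(4)] for r in range(4)]
left_half = [[1, 1, 0, 0] for _ in range(4)]
print(spread_score(checkerboard), spread_score(left_half))
```

The checkerboard selection gives I = -1 and therefore a score of 1.0, while selecting only the left half of the grid gives I = 2/3 and a score of about 0.17, matching the dispersed and clustered ends of the scale.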


Integration with torchgeo

import rasterio
from rs_patch_sampler import PatchSampler, SamplerConfig
from torch.utils.data import DataLoader
from torchgeo.datasets import RasterDataset

# Get patch windows from rs-patch-sampler
sampler = PatchSampler("sentinel2.tif", SamplerConfig(patch_size=256, target_n=1000))
patches = sampler.sample()

# Use windows directly with rasterio
with rasterio.open("sentinel2.tif") as src:
    for p in patches:
        data = src.read(window=p.window)
        # feed to your model...

Configuration reference

| Parameter | Default | Description |
| --- | --- | --- |
| patch_size | 256 | Side length in pixels |
| max_nodata_fraction | 0.05 | Rejection threshold (0–1) |
| stride | patch_size | Stride between candidate origins |
| target_n | None | Max patches (spread selection applied) |
| label_path | None | Label raster for class stratification |
| class_weights | None | {class_id: weight} for stratified selection |
| nodata_value | None | Override raster metadata nodata |
| seed | 42 | Random seed for reproducibility |
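
To illustrate how patch_size and stride interact, the sketch below enumerates candidate windows on a regular grid. The function candidate_windows is hypothetical, and it assumes partial patches at the right and bottom edges are kept and clamped to the raster bounds; the library may handle edges differently.

```python
from typing import List, Optional, Tuple

# (col_off, row_off, width, height), the same order rasterio windows use.
WindowTuple = Tuple[int, int, int, int]


def candidate_windows(
    raster_width: int,
    raster_height: int,
    patch_size: int = 256,
    stride: Optional[int] = None,
) -> List[WindowTuple]:
    """Enumerate grid-aligned candidate patches, clamping those at the edges."""
    stride = patch_size if stride is None else stride  # default: no overlap
    if patch_size <= 0 or stride <= 0:
        raise ValueError("patch_size and stride must be positive")

    windows: List[WindowTuple] = []
    for row_off in range(0, raster_height, stride):
        for col_off in range(0, raster_width, stride):
            width = min(patch_size, raster_width - col_off)
            height = min(patch_size, raster_height - row_off)
            windows.append((col_off, row_off, width, height))
    return windows


# A 600x500 raster: stride == patch_size tiles it, a smaller stride overlaps.
print(len(candidate_windows(600, 500, patch_size=256)))
print(len(candidate_windows(600, 500, patch_size=256, stride=128)))
```

Halving the stride roughly quadruples the number of candidates, which in turn raises the cost of both the nodata reads and the spread selection.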

Benchmarks

On a 10980×10980 Sentinel-2 scene (3 bands, uint16, EPSG:32632):

| Configuration | Candidates | Accepted | Time |
| --- | --- | --- | --- |
| patch_size=256, stride=256, max_nodata=0.05 | 1764 | ~1400 | ~1.2 s |
| patch_size=256, stride=256, target_n=200 | 1764 | 200 | ~1.5 s |
| patch_size=64, stride=64, target_n=500 | 28900 | 500 | ~8 s |

Benchmarks run on a single CPU core. No GPU required.


Development

git clone https://github.com/daudee215/rs-patch-sampler
cd rs-patch-sampler
pip install -e ".[dev]"
pytest tests/ -v
ruff check src/
mypy src/

References

  • torchgeo#1330: "How to avoid nodata-only patches" — torchgeo/torchgeo#1330
  • SRMF (arXiv 2504.19839): motivates smarter sampling for long-tail satellite data
  • Terrain-Informed SSL (arXiv 2311.01188): highlights nodata impact on LiDAR-derived rasters
  • torchgeo#1047: SAR utility discussion motivating geographic spread awareness

License

MIT. See LICENSE.
