Nodata-aware, geographically-stratified patch sampler for satellite raster imagery.
torchgeo.samplers.RandomGeoSampler and GridGeoSampler generate patch coordinates without any awareness of nodata pixels, cloud masks, or spatial distribution. Training on mostly-nodata patches wastes GPU compute and degrades model quality. Selecting correlated adjacent patches inflates apparent accuracy on spatially autocorrelated imagery.
rs-patch-sampler fills the gap with three composable mechanisms:
- Nodata filtering: each candidate patch is evaluated via a cheap windowed read; patches exceeding `max_nodata_fraction` are rejected before any model sees them.
- Geographic spread: when `target_n` is set, patches are selected to maximise spatial spread (farthest-point strategy), reducing spatial autocorrelation in training data.
- Class stratification: when a label raster is provided, class counts per patch guide up-weighting of minority classes during final selection.
```bash
pip install rs-patch-sampler
```

Requires Python ≥ 3.10. Dependencies: rasterio, numpy, shapely, pyproj, click, tqdm.
```python
from rs_patch_sampler import PatchSampler, SamplerConfig

cfg = SamplerConfig(
    patch_size=256,            # pixels
    max_nodata_fraction=0.05,  # reject if >5% nodata
    target_n=500,              # keep 500 spatially-spread patches
    seed=42,
)
sampler = PatchSampler("sentinel2_scene.tif", cfg)
patches = sampler.sample(show_progress=True)
print(f"Accepted {len(patches)} patches")
sampler.to_geojson(patches, "patches.geojson")
```

Each `PatchStats` object exposes:
| Field | Type | Description |
|---|---|---|
| `col_off` | `int` | Column offset of patch origin |
| `row_off` | `int` | Row offset of patch origin |
| `width` | `int` | Actual patch width (clamped at edges) |
| `height` | `int` | Actual patch height (clamped at edges) |
| `nodata_fraction` | `float` | Fraction of nodata pixels in patch |
| `valid_pixels` | `int` | Number of non-nodata pixels |
| `class_counts` | `dict[int,int]` | Class pixel counts (when label raster given) |
| `.window` | `rasterio.Window` | Ready-to-use rasterio windowed read window |
```bash
# Basic: 256×256 patches, max 5% nodata
rs-patch-sampler scene.tif --patch-size 256 --max-nodata 0.05

# Limit to 500 spatially-spread patches, write GeoJSON index
rs-patch-sampler scene.tif --n 500 --out-geojson patches.geojson

# With label raster for class statistics
rs-patch-sampler scene.tif --label labels.tif --n 200 --out-json patches.json --verbose
```

Full CLI help: `rs-patch-sampler --help`
Candidate patches are generated on a regular grid (stride defaults to patch_size). For each candidate, the sampler performs a windowed read of the first raster band — O(patch_size²) pixels, not the full scene. A pixel is marked nodata if:
- The raster declares a nodata value and the pixel matches it, OR
- Any band among the first four has the nodata value (union mask for multi-band rasters), OR
- No nodata is declared and the pixel equals 0 (common sentinel for uint data).
Patches where nodata_count / total_pixels > max_nodata_fraction are rejected.
When target_n is set and fewer candidates than target_n exist after nodata filtering, all valid patches are returned. When more candidates than target_n exist, the greedy farthest-point algorithm selects target_n patches:
- Seed with a random patch.
- Iteratively add the candidate that maximises minimum distance to already-selected patches.
This is O(target_n × |candidates|) — acceptable for typical grid sizes (< 100k patches). The spatial spread is measured post-hoc via a lightweight Moran's I approximation on a rook-contiguous grid.
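The greedy selection described above can be sketched in a few lines of NumPy. This is a minimal sketch, not the package's code: `farthest_point_select` is a hypothetical name, and it assumes patch centres are given as an `(n, 2)` array in any planar coordinate system with Euclidean distance.

```python
import numpy as np


def farthest_point_select(centers: np.ndarray, target_n: int, seed: int = 42) -> list[int]:
    """Return indices of up to `target_n` points chosen by greedy farthest-point sampling."""
    n = len(centers)
    if n <= target_n:
        return list(range(n))  # too few candidates: keep them all

    rng = np.random.default_rng(seed)
    first = int(rng.integers(n))  # random seed patch
    selected = [first]

    # min_dist[i] = distance from candidate i to its nearest selected point.
    min_dist = np.linalg.norm(centers - centers[first], axis=1)
    min_dist[first] = -np.inf  # never pick the same candidate twice

    while len(selected) < target_n:
        nxt = int(np.argmax(min_dist))  # candidate farthest from the selection
        selected.append(nxt)
        dist_to_new = np.linalg.norm(centers - centers[nxt], axis=1)
        min_dist = np.minimum(min_dist, dist_to_new)
        min_dist[nxt] = -np.inf
    return selected


# Centres of a 10x10 grid of 256-pixel patches.
centers = np.array(
    [(c * 256 + 128, r * 256 + 128) for r in range(10) for c in range(10)],
    dtype=float,
)
picked = farthest_point_select(centers, target_n=5)
print(picked)
```

Keeping a running `min_dist` array means each iteration costs one pass over the candidates, which gives the O(target_n × |candidates|) cost quoted above.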
When label_path is given, class pixel counts are computed for each accepted patch via a second windowed read on the label raster. If class_weights are provided, patches are down/up-weighted by their dominant class during a final probabilistic selection step.
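One way to realise a dominant-class weighted draw is sketched below. The README does not say which sampling scheme the package uses, so this example swaps in the Efraimidis-Spirakis method for weighted sampling without replacement. The function names and the default weight of 1.0 for classes missing from `class_weights` are assumptions made for the illustration.

```python
import random


def dominant_class(class_counts: dict[int, int]) -> int:
    """Class id with the most pixels in a patch (assumes a non-empty dict)."""
    return max(class_counts, key=lambda c: class_counts[c])


def weighted_select(
    counts_per_patch: list[dict[int, int]],
    class_weights: dict[int, float],
    k: int,
    seed: int = 42,
) -> list[int]:
    """Draw up to k patch indices without replacement, weighted by dominant class."""
    rng = random.Random(seed)
    keyed = []
    for idx, counts in enumerate(counts_per_patch):
        # Assumption: classes absent from class_weights get weight 1.0.
        w = class_weights.get(dominant_class(counts), 1.0)
        if w > 0:
            # Efraimidis-Spirakis key: larger weights push the key toward 1.
            keyed.append((rng.random() ** (1.0 / w), idx))
    keyed.sort(reverse=True)
    return [idx for _, idx in keyed[:k]]


# Eight patches dominated by class 0, two dominated by the rarer class 1.
counts = [{0: 900, 1: 100}] * 8 + [{0: 100, 1: 900}] * 2
picked = weighted_select(counts, class_weights={0: 1.0, 1: 5.0}, k=4)
print(picked)
```

A weight of zero excludes a class entirely, while a weight above 1.0 makes patches dominated by that class more likely to survive the draw.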
```python
score = sampler.spread_score()  # float in [0, 1], higher = more dispersed
print(f"Geographic spread: {score:.3f}")
```

The spread score wraps Moran's I: score = (1 - I) / 2. A value near 1.0 means patches are maximally dispersed; near 0.5 means random; near 0.0 means clustered.
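To make the mapping from Moran's I to the score concrete, here is a small sketch that computes Moran's I with rook (4-neighbour) contiguity on a grid and applies score = (1 - I) / 2. It assumes a binary occupancy grid with 1 where a patch was selected; that representation and the function names are illustrative, and the package's own approximation may differ.

```python
import numpy as np


def morans_i_rook(grid: np.ndarray) -> float:
    """Moran's I on a 2-D grid with binary rook-contiguity weights."""
    z = grid.astype(float) - grid.mean()
    denom = (z ** 2).sum()
    if denom == 0:
        return 0.0  # constant grid: no spatial pattern to measure

    # Cross-products over horizontal and vertical neighbour pairs.
    horizontal = (z[:, :-1] * z[:, 1:]).sum()
    vertical = (z[:-1, :] * z[1:, :]).sum()
    numerator = 2.0 * (horizontal + vertical)  # each pair counts as w_ij and w_ji

    rows, cols = grid.shape
    w_total = 2.0 * (rows * (cols - 1) + cols * (rows - 1))  # sum of all weights
    return float(grid.size / w_total * numerator / denom)


def spread_score_from_grid(grid: np.ndarray) -> float:
    """Map Moran's I from [-1, 1] onto [0, 1]; higher means more dispersed."""
    return (1.0 - morans_i_rook(grid)) / 2.0


checkerboard = np.indices((4, 4)).sum(axis=0) % 2  # perfectly alternating selection
clustered = np.zeros((4, 4))
clustered[:, :2] = 1  # all selected patches on one side

print(spread_score_from_grid(checkerboard))
print(spread_score_from_grid(clustered))
```

The checkerboard reaches the dispersed end of the scale (I = -1), while the half-filled grid falls well below 0.5.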
```python
import rasterio
from torch.utils.data import DataLoader
from torchgeo.datasets import RasterDataset

from rs_patch_sampler import PatchSampler, SamplerConfig

# Get patch windows from rs-patch-sampler
sampler = PatchSampler("sentinel2.tif", SamplerConfig(patch_size=256, target_n=1000))
patches = sampler.sample()

# Use windows directly with rasterio
with rasterio.open("sentinel2.tif") as src:
    for p in patches:
        data = src.read(window=p.window)
        # feed to your model...
```

| Parameter | Default | Description |
|---|---|---|
| `patch_size` | `256` | Side length in pixels |
| `max_nodata_fraction` | `0.05` | Rejection threshold (0–1) |
| `stride` | `patch_size` | Stride between candidate origins |
| `target_n` | `None` | Max patches (spread selection applied) |
| `label_path` | `None` | Label raster for class stratification |
| `class_weights` | `None` | `{class_id: weight}` for stratified selection |
| `nodata_value` | `None` | Override raster metadata nodata |
| `seed` | `42` | Random seed for reproducibility |
On a 10980×10980 Sentinel-2 scene (3 bands, uint16, EPSG:32632):
| Configuration | Candidates | Accepted | Time |
|---|---|---|---|
| patch_size=256, stride=256, max_nodata=0.05 | 1764 | ~1400 | ~1.2s |
| patch_size=256, stride=256, target_n=200 | 1764 | 200 | ~1.5s |
| patch_size=64, stride=64, target_n=500 | 28900 | 500 | ~8s |
Benchmarks run on a single CPU core. No GPU required.
```bash
git clone https://github.com/daudee215/rs-patch-sampler
cd rs-patch-sampler
pip install -e ".[dev]"
pytest tests/ -v
ruff check src/
mypy src/
```

- torchgeo#1330: "How to avoid nodata-only patches" (torchgeo/torchgeo#1330)
- SRMF (arXiv 2504.19839): motivates smarter sampling for long-tail satellite data
- Terrain-Informed SSL (arXiv 2311.01188): highlights nodata impact on LiDAR-derived rasters
- torchgeo#1047: SAR utility discussion motivating geographic spread awareness
MIT. See LICENSE.