In [None]:
from torchgeo.datasets import RasterDataset, VectorDataset
import torch 

class RAMPImageDataset(RasterDataset):
    filename_glob = "*.tif"
    is_image = True
    all_bands = ("R", "G", "B")
    rgb_bands = ("R", "G", "B")

class RAMPMaskDataset(VectorDataset):
    filename_glob = "*.geojson"

    def __init__(self, paths, crs=None, **kwargs):
        super().__init__(paths=paths, crs=crs, task="instance_segmentation", **kwargs)

    def __getitem__(self, index):
        sample = super().__getitem__(index)
        m = sample["mask"]
        if m.ndim == 3:
            keep = m.view(m.shape[0], -1).sum(dim=1) > 0
            sample["mask"] = m[keep]
            sample["label"] = sample["label"][keep]
            sample["bbox_xyxy"] = sample["bbox_xyxy"][keep]
            if sample["mask"].numel() == 0:
                sample["mask"] = torch.zeros((0, m.shape[-2], m.shape[-1]), dtype=m.dtype)
                sample["label"] = sample["label"][:0]
                sample["bbox_xyxy"] = sample["bbox_xyxy"][:0]
        return sample

In [85]:
from pyproj import CRS
from pathlib import Path
image_paths, label_paths = [], []

region_path = Path('/home/krschap/data/banepa_leg')
img_path, lbl_path = region_path / "source", region_path / "labels"
if img_path.exists() and lbl_path.exists():
    image_paths.append(img_path)
    label_paths.append(lbl_path)

if not image_paths:
    raise ValueError(f"No valid regions found in {region_path}")

# target_crs = CRS.from_epsg(3857)
print("Loading images ...")
print(image_paths[0])
target_res = (0.44696360694797477, 0.4795358536102867)
images = RAMPImageDataset(paths=image_paths)
print(
    f"Loaded {len(images)} image tiles. using crs : {images.crs} with res {images.res}"
)
print("Loading labels ...")
masks = RAMPMaskDataset(paths=label_paths)
print(label_paths[0])
print(
    f"Loaded {len(masks)} mask tiles. using crs : {masks.crs} with res {masks.res}"
)


Loading images ...
/home/krschap/data/banepa_leg/source
Loaded 374 image tiles. using crs : GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AXIS["Latitude",NORTH],AXIS["Longitude",EAST],AUTHORITY["EPSG","4326"]] with res (2.682209014892578e-06, 2.376153779362622e-06)
Loading labels ...
/home/krschap/data/banepa_leg/labels
Loaded 1 mask tiles. using crs : EPSG:4326 with res (0.0001, 0.0001)


In [86]:
whole_dataset = images & masks

Converting RAMPMaskDataset res from (0.0001, 0.0001) to (2.682209014892578e-06, 2.376153779362622e-06)


In [87]:
whole_dataset.crs

<Geographic 2D CRS: EPSG:4326>
Name: WGS 84
Axis Info [ellipsoidal]:
- Lat[north]: Geodetic latitude (degree)
- Lon[east]: Geodetic longitude (degree)
Area of Use:
- undefined
Datum: World Geodetic System 1984
- Ellipsoid: WGS 84
- Prime Meridian: Greenwich

In [88]:
from torchgeo.samplers import RandomGeoSampler, Units


sampler = RandomGeoSampler(
    whole_dataset,
    size=256,
    units=Units.PIXELS,
)
print(len(sampler))

60


In [89]:
import pandas as pd 
import numpy as np

for i, bbox in enumerate(sampler):
    try:
        s = whole_dataset[bbox]
    except Exception as e:
        print("fail at", i, bbox, e)
        break
    m = s["mask"]
    if m.ndim == 3 and m.shape[0] != 0:
        ids = np.unique(m.numpy())
        ids = ids[ids != 0]
        if m.shape[0] != len(ids):
            print("id mismatch at", i, bbox, "mask planes", m.shape[0], "ids", len(ids))
            break

id mismatch at 0 (slice(85.52238464355469, 85.5230712890625, None), slice(27.63730702015522, 27.637915315522736, None), slice(Timestamp('1677-09-21 00:12:43.145224193'), Timestamp('2262-04-11 23:47:16.854775807'), None)) mask planes 10 ids 1
