Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CDL/NLCD/SSL4EO: allow selection of classes #1392

Merged
merged 12 commits into from
Jun 4, 2023
4 changes: 2 additions & 2 deletions tests/conf/ssl4eo_l_benchmark_cdl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ module:
datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
root: "tests/data/ssl4eo_benchmark_landsat"
input_sensor: "tm_toa"
mask_product: "cdl"
sensor: "tm_toa"
product: "cdl"
batch_size: 2
num_workers: 0
4 changes: 2 additions & 2 deletions tests/conf/ssl4eo_l_benchmark_nlcd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ module:
datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
root: "tests/data/ssl4eo_benchmark_landsat"
input_sensor: "etm_sr"
mask_product: "nlcd"
sensor: "etm_sr"
product: "nlcd"
batch_size: 2
num_workers: 0
Binary file added tests/data/cdl/2020_30m_cdls/2020_30m_cdls.tif
Binary file not shown.
Binary file not shown.
Binary file added tests/data/cdl/2021_30m_cdls/2021_30m_cdls.tif
Binary file not shown.
Binary file not shown.
Empty file modified tests/data/fire_risk/data.py
100644 → 100755
Empty file.
Empty file modified tests/data/nlcd/data.py
100644 → 100755
Empty file.
Empty file modified tests/data/ref_cloud_cover_detection_challenge_v1/data.py
100644 → 100755
Empty file.
Empty file modified tests/data/skippd/data.py
100644 → 100755
Empty file.
Empty file modified tests/data/spacenet/data.py
100644 → 100755
Empty file.
Empty file modified tests/data/ssl4eo_benchmark_landsat/data.py
100644 → 100755
Empty file.
Empty file modified tests/data/sustainbench_crop_yield/data.py
100644 → 100755
Empty file.
Empty file modified tests/data/vhr10/data.py
100644 → 100755
Empty file.
Empty file modified tests/data/western_usa_live_fuel_moisture/data.py
100644 → 100755
Empty file.
15 changes: 15 additions & 0 deletions tests/datasets/test_cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ def test_getitem(self, dataset: CDL) -> None:
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)

def test_classes(self) -> None:
root = os.path.join("tests", "data", "cdl")
classes = list(CDL.cmap.keys())[:5]
ds = CDL(root, years=[2021], classes=classes)
sample = ds[ds.bounds]
mask = sample["mask"]
assert mask.max() < len(classes)

def test_and(self, dataset: CDL) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)
Expand Down Expand Up @@ -82,6 +90,13 @@ def test_invalid_year(self, tmp_path: Path) -> None:
):
CDL(str(tmp_path), years=[1996])

def test_invalid_classes(self) -> None:
with pytest.raises(AssertionError):
CDL(classes=[-1])

with pytest.raises(AssertionError):
CDL(classes=[11])

def test_plot(self, dataset: CDL) -> None:
query = dataset.bounds
x = dataset[query]
Expand Down
15 changes: 15 additions & 0 deletions tests/datasets/test_nlcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ def test_getitem(self, dataset: NLCD) -> None:
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)

def test_classes(self) -> None:
root = os.path.join("tests", "data", "nlcd")
classes = list(NLCD.cmap.keys())[:5]
ds = NLCD(root, years=[2019], classes=classes)
sample = ds[ds.bounds]
mask = sample["mask"]
assert mask.max() < len(classes)

def test_and(self, dataset: NLCD) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)
Expand All @@ -78,6 +86,13 @@ def test_invalid_year(self, tmp_path: Path) -> None:
):
NLCD(str(tmp_path), years=[1996])

def test_invalid_classes(self) -> None:
with pytest.raises(AssertionError):
NLCD(classes=[-1])

with pytest.raises(AssertionError):
NLCD(classes=[11])

def test_plot(self, dataset: NLCD) -> None:
query = dataset.bounds
x = dataset[query]
Expand Down
36 changes: 26 additions & 10 deletions tests/datasets/test_ssl4eo_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch.utils.data import ConcatDataset

import torchgeo.datasets.utils
from torchgeo.datasets import SSL4EOLBenchmark
from torchgeo.datasets import CDL, NLCD, RasterDataset, SSL4EOLBenchmark


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
Expand All @@ -42,7 +42,7 @@ def dataset(
url = os.path.join("tests", "data", "ssl4eo_benchmark_landsat", "{}.tar.gz")
monkeypatch.setattr(SSL4EOLBenchmark, "url", url)

input_sensor, mask_product, split = request.param
sensor, product, split = request.param
monkeypatch.setattr(
SSL4EOLBenchmark, "split_percentages", [1 / 3, 1 / 3, 1 / 3]
)
Expand Down Expand Up @@ -75,8 +75,8 @@ def dataset(
transforms = nn.Identity()
return SSL4EOLBenchmark(
root=root,
input_sensor=input_sensor,
mask_product=mask_product,
sensor=sensor,
product=product,
split=split,
transforms=transforms,
download=True,
Expand All @@ -89,17 +89,33 @@ def test_getitem(self, dataset: SSL4EOLBenchmark) -> None:
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)

@pytest.mark.parametrize("product,base_class", [("nlcd", NLCD), ("cdl", CDL)])
def test_classes(self, product: str, base_class: RasterDataset) -> None:
root = os.path.join("tests", "data", "ssl4eo_benchmark_landsat")
classes = list(base_class.cmap.keys())[:5]
ds = SSL4EOLBenchmark(root, product=product, classes=classes)
sample = ds[0]
mask = sample["mask"]
assert mask.max() < len(classes)

def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
SSL4EOLBenchmark(split="foo")

def test_invalid_input_sensor(self) -> None:
def test_invalid_sensor(self) -> None:
with pytest.raises(AssertionError):
SSL4EOLBenchmark(sensor="foo")

def test_invalid_product(self) -> None:
with pytest.raises(AssertionError):
SSL4EOLBenchmark(product="foo")

def test_invalid_classes(self) -> None:
with pytest.raises(AssertionError):
SSL4EOLBenchmark(input_sensor="foo")
SSL4EOLBenchmark(classes=[-1])

def test_invalid_mask_product(self) -> None:
with pytest.raises(AssertionError):
SSL4EOLBenchmark(mask_product="foo")
SSL4EOLBenchmark(classes=[11])

def test_add(self, dataset: SSL4EOLBenchmark) -> None:
ds = dataset + dataset
Expand All @@ -108,8 +124,8 @@ def test_add(self, dataset: SSL4EOLBenchmark) -> None:
def test_already_extracted(self, dataset: SSL4EOLBenchmark) -> None:
SSL4EOLBenchmark(
root=dataset.root,
input_sensor=dataset.input_sensor,
mask_product=dataset.mask_product,
sensor=dataset.sensor,
product=dataset.product,
download=True,
)

Expand Down
Loading