diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index 8934b32620f..01ed9d6162d 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -125,6 +125,11 @@ SSL4EO .. autoclass:: SSL4EOS12DataModule +SustainBench Crop Yield +^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: SustainBenchCropYieldDataModule + Tropical Cyclone ^^^^^^^^^^^^^^^^ diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 53b9642a860..93317f4270e 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -324,6 +324,11 @@ SSL4EO .. autoclass:: SSL4EOS12 +SustainBench Crop Yield +^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: SustainBenchCropYield + Tropical Cyclone ^^^^^^^^^^^^^^^^ diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv index 2d49718d81b..0bd61c05115 100644 --- a/docs/api/non_geo_datasets.csv +++ b/docs/api/non_geo_datasets.csv @@ -30,6 +30,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands `So2Sat`_,C,Sentinel-1/2,"400,673",17,32x32,10,"SAR, MSI" `SpaceNet`_,I,WorldView-2/3 Planet Lab Dove,"1,889--28,728",2,102--900,0.5--4,MSI `SSL4EO`_,T,Sentinel-1/2,1M,-,264x264,10,"SAR, MSI" +`SustainBench Crop Yield`_,R,MODIS,11k,-,32x32,-,MSI `Tropical Cyclone`_,R,GOES 8--16,"108,110",-,256x256,4K--8K,MSI `UC Merced`_,C,USGS National Map,"21,000",21,256x256,0.3,RGB `USAVars`_,R,NAIP Aerial,100K,-,-,4,"RGB, NIR" diff --git a/tests/conf/sustainbench_crop_yield.yaml b/tests/conf/sustainbench_crop_yield.yaml new file mode 100644 index 00000000000..60903ea7d4c --- /dev/null +++ b/tests/conf/sustainbench_crop_yield.yaml @@ -0,0 +1,14 @@ +experiment: + task: "sustainbench_crop_yield" + module: + model: "resnet18" + weights: null + num_outputs: 1 + in_channels: 9 + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + datamodule: + root: "tests/data/sustainbench_crop_yield" + download: true + batch_size: 1 + num_workers: 0 diff --git a/tests/data/sustainbench_crop_yield/data.py b/tests/data/sustainbench_crop_yield/data.py new file mode 100644 index 00000000000..46b2e653ac6 --- /dev/null +++ b/tests/data/sustainbench_crop_yield/data.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import shutil + +import numpy as np + +SIZE = 32 +NUM_SAMPLES = 3 +NUM_BANDS = 9 + +np.random.seed(0) + +countries = ["argentina", "brazil", "usa"] +splits = ["train", "dev", "test"] + +root_dir = "soybeans" + + +def create_files(path: str, split: str) -> None: + hist_img = np.random.random(size=(NUM_SAMPLES, SIZE, SIZE, NUM_BANDS)) + np.savez(os.path.join(path, f"{split}_hists.npz"), data=hist_img) + + target = np.random.random(size=(NUM_SAMPLES, 1)) + np.savez(os.path.join(path, f"{split}_yields.npz"), data=target) + + ndvi = np.random.random(size=(NUM_SAMPLES, SIZE)) + np.savez(os.path.join(path, f"{split}_ndvi.npz"), data=ndvi) + + year = np.array(["2009"] * NUM_SAMPLES, dtype=" None: + shutil.copy(url, root) + + +class TestSustainBenchCropYield: + @pytest.fixture(params=["train", "dev", "test"]) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> SustainBenchCropYield: + monkeypatch.setattr( + torchgeo.datasets.sustainbench_crop_yield, "download_url", download_url + ) + + md5 = "7a5591794e14dd73d2b747cd2244acbc" + monkeypatch.setattr(SustainBenchCropYield, "md5", md5) + url = os.path.join("tests", "data", "sustainbench_crop_yield", "soybeans.zip") + monkeypatch.setattr(SustainBenchCropYield, "url", url) + monkeypatch.setattr(plt, "show", lambda *args: None) + root = str(tmp_path) + split = request.param + countries = ["argentina", "brazil", "usa"] + transforms = nn.Identity() + return SustainBenchCropYield( + root, split, countries, transforms, download=True, checksum=True + ) + + def test_already_extracted(self, dataset: SustainBenchCropYield) -> None: + SustainBenchCropYield(root=dataset.root, download=True) + + def test_already_downloaded(self, tmp_path: Path) -> None: + pathname = os.path.join( + "tests", "data", "sustainbench_crop_yield", "soybeans.zip" + ) + root = str(tmp_path) + shutil.copy(pathname, root) + SustainBenchCropYield(root) + + @pytest.mark.parametrize("index", [0, 1, 2]) + def test_getitem(self, dataset: SustainBenchCropYield, index: int) -> None: + x = dataset[index] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["label"], torch.Tensor) + assert isinstance(x["year"], torch.Tensor) + assert isinstance(x["ndvi"], torch.Tensor) + assert x["image"].shape == (9, 32, 32) + + def test_len(self, dataset: SustainBenchCropYield) -> None: + assert len(dataset) == len(dataset.countries) * 3 + + def test_invalid_split(self) -> None: + with pytest.raises(AssertionError): + SustainBenchCropYield(split="foo") + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found in"): + SustainBenchCropYield(str(tmp_path)) + + def test_plot(self, dataset: SustainBenchCropYield) -> None: + dataset.plot(dataset[0], suptitle="Test") + plt.close() + + sample = dataset[0] + sample["prediction"] = sample["label"] + dataset.plot(sample) + plt.close() diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index f311f38e28e..f4210a7cfa9 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -19,6 +19,7 @@ COWCCountingDataModule, MisconfigurationException, SKIPPDDataModule, + SustainBenchCropYieldDataModule, TropicalCycloneDataModule, ) from torchgeo.datasets import TropicalCyclone @@ -29,8 +30,8 @@ class RegressionTestModel(ClassificationTestModel): - def __init__(self, **kwargs: Any) -> None: - super().__init__(in_chans=3, num_classes=1) + def __init__(self, in_chans: int = 3, num_classes: int = 1, **kwargs: Any) -> None: + super().__init__(in_chans=in_chans, num_classes=num_classes) class PredictRegressionDataModule(TropicalCycloneDataModule): @@ -53,6 +54,7 @@ class TestRegressionTask: [ ("cowc_counting", COWCCountingDataModule), ("cyclone", TropicalCycloneDataModule), + ("sustainbench_crop_yield", SustainBenchCropYieldDataModule), ("skippd", SKIPPDDataModule), ], ) @@ -71,7 +73,10 @@ def test_trainer( model_kwargs = conf_dict["module"] model = RegressionTask(**model_kwargs) - model.model = RegressionTestModel() + model.model = RegressionTestModel( + in_chans=model_kwargs["in_channels"], + num_classes=model_kwargs["num_outputs"], + ) # Instantiate trainer trainer = Trainer( @@ -80,6 +85,7 @@ def test_trainer( log_every_n_steps=1, max_epochs=1, ) + trainer.fit(model=model, datamodule=datamodule) try: trainer.test(model=model, datamodule=datamodule) diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index edced09bc90..4265285235d 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -29,6 +29,7 @@ from .so2sat import So2SatDataModule from .spacenet import SpaceNet1DataModule from .ssl4eo import SSL4EOS12DataModule +from .sustainbench_crop_yield import SustainBenchCropYieldDataModule from .ucmerced import UCMercedDataModule from .usavars import USAVarsDataModule from .utils import MisconfigurationException @@ -63,6 +64,7 @@ "So2SatDataModule", "SpaceNet1DataModule", "SSL4EOS12DataModule", + "SustainBenchCropYieldDataModule", "TropicalCycloneDataModule", "UCMercedDataModule", "USAVarsDataModule", diff --git a/torchgeo/datamodules/sustainbench_crop_yield.py b/torchgeo/datamodules/sustainbench_crop_yield.py new file mode 100644 index 00000000000..b3283bd36e7 --- /dev/null +++ b/torchgeo/datamodules/sustainbench_crop_yield.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""SustainBench Crop Yield datamodule.""" + +from typing import Any + +from ..datasets import SustainBenchCropYield +from .geo import NonGeoDataModule + + +class SustainBenchCropYieldDataModule(NonGeoDataModule): + """LightningDataModule for SustainBench Crop Yield dataset. + + .. versionadded:: 0.5 + """ + + def __init__( + self, batch_size: int = 32, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a new SustainBenchCropYieldDataModule instance. + + Args: + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.SustainBenchCropYield`. + """ + super().__init__(SustainBenchCropYield, batch_size, num_workers, **kwargs) + + def setup(self, stage: str) -> None: + """Set up datasets. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + if stage in ["fit"]: + self.train_dataset = SustainBenchCropYield(split="train", **self.kwargs) + if stage in ["fit", "validate"]: + self.val_dataset = SustainBenchCropYield(split="dev", **self.kwargs) + if stage in ["test"]: + self.test_dataset = SustainBenchCropYield(split="test", **self.kwargs) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index d3eb9882593..8761bdc0894 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -104,6 +104,7 @@ time_series_split, ) from .ssl4eo import SSL4EOS12 +from .sustainbench_crop_yield import SustainBenchCropYield from .ucmerced import UCMerced from .usavars import USAVars from .utils import ( @@ -207,6 +208,7 @@ "SpaceNet6", "SpaceNet7", "SSL4EOS12", + "SustainBenchCropYield", "TropicalCyclone", "UCMerced", "USAVars", diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py new file mode 100644 index 00000000000..acb44a5ce29 --- /dev/null +++ b/torchgeo/datasets/sustainbench_crop_yield.py @@ -0,0 +1,274 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""SustainBench Crop Yield dataset.""" + +import os +from typing import Any, Callable, Optional + +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch import Tensor + +from .geo import NonGeoDataset +from .utils import download_url, extract_archive + + +class SustainBenchCropYield(NonGeoDataset): + """SustainBench Crop Yield Dataset. + + This dataset contains MODIS band histograms and soybean yield + estimates for selected counties in the USA, Argentina and Brazil. + The dataset is part of the + `SustainBench `_ + datasets for tackling the UN Sustainable Development Goals (SDGs). + + Dataset Format: + + * .npz files of stacked samples + + Dataset Features: + + * input histogram of 7 surface reflectance and 2 surface temperature + bands from MODIS pixel values in 32 ranges across 32 timesteps + resulting in 32x32x9 input images + * regression target value of soybean yield in metric tonnes per + harvested hectare + + If you use this dataset in your research, please cite: + + * https://doi.org/10.1145/3209811.3212707 + * https://doi.org/10.1609/aaai.v31i1.11172 + + .. versionadded:: 0.5 + """ # noqa: E501 + + valid_countries = ["usa", "brazil", "argentina"] + + md5 = "c2794e59512c897d9bea77b112848122" + + url = "https://drive.google.com/file/d/1odwkI1hiE5rMZ4VfM0hOXzlFR4NbhrfU/view?usp=share_link" # noqa: E501 + + dir = "soybeans" + + valid_splits = ["train", "dev", "test"] + + def __init__( + self, + root: str = "data", + split: str = "train", + countries: list[str] = ["usa"], + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new Dataset instance. + + Args: + root: root directory where dataset can be found + split: one of "train", "dev", or "test" + countries: which countries to include in the dataset + transforms: a function/transform that takes an input sample + and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 after downloading files (may be slow) + + Raises: + AssertionError: if ``countries`` contains invalid countries or if ``split`` + is invalid + RuntimeError: if ``download=False`` but dataset is missing or checksum fails + """ + assert set(countries).issubset( + self.valid_countries + ), f"Please choose a subset of these valid countried: {self.valid_countries}." + self.countries = countries + + assert ( + split in self.valid_splits + ), f"Pleas choose one of these valid data splits {self.valid_splits}." + self.split = split + + self.root = root + self.transforms = transforms + self.download = download + self.checksum = checksum + + self._verify() + self.collection = self.retrieve_collection() + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.collection) + + def __getitem__(self, index: int) -> dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data and label at that index + """ + input_file_path, sample_idx = self.collection[index] + + sample: dict[str, Tensor] = { + "image": self._load_image(input_file_path, sample_idx) + } + sample.update(self._load_features(input_file_path, sample_idx)) + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def _load_image(self, path: str, sample_idx: int) -> Tensor: + """Load input image. + + Args: + path: path to input npz collection + sample_idx: what sample to index from the npz collection + + Returns: + input image as tensor + """ + arr = np.load(path)["data"][sample_idx] + # return [channel, height, width] + return torch.from_numpy(arr).permute(2, 0, 1).to(torch.float32) + + def _load_features(self, path: str, sample_idx: int) -> dict[str, Tensor]: + """Load features value. + + Args: + path: path to image npz collection + sample_idx: what sample to index from the npz collection + + Returns: + target regression value + """ + target_file_path = path.replace("_hists", "_yields") + target = np.load(target_file_path)["data"][sample_idx] + + years_file_path = path.replace("_hists", "_years") + year = int(np.load(years_file_path)["data"][sample_idx]) + + ndvi_file_path = path.replace("_hists", "_ndvi") + ndvi = np.load(ndvi_file_path)["data"][sample_idx] + + features = { + "label": torch.tensor(target).to(torch.float32), + "year": torch.tensor(year), + "ndvi": torch.from_numpy(ndvi).to(dtype=torch.float32), + } + return features + + def retrieve_collection(self) -> list[tuple[str, int]]: + """Retrieve the collection. + + Returns: + path and index to dataset samples + """ + collection = [] + for country in self.countries: + file_path = os.path.join( + self.root, self.dir, country, f"{self.split}_hists.npz" + ) + npz_file = np.load(file_path) + num_data_points = npz_file["data"].shape[0] + for idx in range(num_data_points): + collection.append((file_path, idx)) + + return collection + + def _verify(self) -> None: + """Verify the integrity of the dataset. + + Raises: + RuntimeError: if ``download=False`` but dataset is missing or checksum fails + """ + # Check if the extracted files already exist + pathname = os.path.join(self.root, self.dir) + if os.path.exists(pathname): + return + + # Check if the zip files have already been downloaded + pathname = os.path.join(self.root, self.dir) + ".zip" + if os.path.exists(pathname): + self._extract() + return + + # Check if the user requested to download the dataset + if not self.download: + raise RuntimeError( + f"Dataset not found in `root={self.root}` and `download=False`, " + "either specify a different `root` directory or use `download=True` " + "to automatically download the dataset." + ) + + # Download the dataset + self._download() + self._extract() + + def _download(self) -> None: + """Download the dataset and extract it. + + Raises: + RuntimeError: if download doesn't work correctly or checksums don't match + """ + download_url( + self.url, + self.root, + filename=self.dir, + md5=self.md5 if self.checksum else None, + ) + self._extract() + + def _extract(self) -> None: + """Extract the dataset.""" + zipfile_path = os.path.join(self.root, self.dir) + ".zip" + extract_archive(zipfile_path, self.root) + + def plot( + self, + sample: dict[str, Any], + band_idx: int = 0, + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample return by :meth:`__getitem__` + band_idx: which of the nine histograms to index + show_titles: flag indicating whether to show titles above each panel + suptitle: optional suptitle to use for figure + + Returns: + a matplotlib Figure with the rendered sample + + """ + image, label = sample["image"], sample["label"].item() + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction = sample["prediction"].item() + + fig, ax = plt.subplots(1, 1, figsize=(10, 10)) + + ax.imshow(image.permute(1, 2, 0)[:, :, band_idx]) + ax.axis("off") + + if show_titles: + title = f"Label: {label:.3f}" + if showing_predictions: + title += f"\nPrediction: {prediction:.3f}" + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig