Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SustainBench Crop Yield Dataset #1253

Merged
merged 15 commits into from
Apr 20, 2023
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ SSL4EO

.. autoclass:: SSL4EOS12DataModule

SustainbenchCropYieldDataModule
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: SustainbenchCropYieldDataModule

Tropical Cyclone
^^^^^^^^^^^^^^^^

Expand Down
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,11 @@ SSL4EO

.. autoclass:: SSL4EOS12

SustainBenchCropYieldPrediction
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: SustainBenchCropYieldPrediction

Tropical Cyclone
^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`So2Sat`_,C,Sentinel-1/2,"400,673",17,32x32,10,"SAR, MSI"
`SpaceNet`_,I,WorldView-2/3 Planet Lab Dove,"1,889--28,728",2,102--900,0.5--4,MSI
`SSL4EO`_,T,Sentinel-1/2,1M,-,264x264,10,"SAR, MSI"
`SustainBenchCropYieldPrediction`_,R,MODIS,11k,-,32x32,-,MSI
`Tropical Cyclone`_,R,GOES 8--16,"108,110",-,256x256,4K--8K,MSI
`UC Merced`_,C,USGS National Map,"21,000",21,256x256,0.3,RGB
`USAVars`_,R,NAIP Aerial,100K,-,-,4,"RGB, NIR"
Expand Down
60 changes: 60 additions & 0 deletions tests/data/sustainbench_crop_yield_prediction/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil

import numpy as np

SIZE = 32          # spatial size (pixels per side) of each fake histogram image
NUM_SAMPLES = 3    # number of samples generated per split
NUM_BANDS = 9      # number of spectral bands per histogram

# Fixed seed so the generated archive (and its md5 checksum) is reproducible.
np.random.seed(0)

countries = ["argentina", "brazil", "usa"]
splits = ["train", "dev", "test"]

root_dir = "soybeans"


def create_files(path: str, split: str) -> None:
    """Write the four fake ``.npz`` files for one split into ``path``.

    Produces ``{split}_hists.npz``, ``{split}_yields.npz``,
    ``{split}_ndvi.npz`` and ``{split}_years.npz``, each with a single
    ``data`` array.
    """
    # Build the arrays in the same order as the random draws so the RNG
    # stream (and therefore the archive checksum) stays reproducible.
    arrays = {
        "hists": np.random.random(size=(NUM_SAMPLES, SIZE, SIZE, NUM_BANDS)),
        "yields": np.random.random(size=(NUM_SAMPLES, 1)),
        "ndvi": np.random.random(size=(NUM_SAMPLES, SIZE)),
        "years": np.array(["2009"] * NUM_SAMPLES, dtype="<U4"),
    }
    for suffix, array in arrays.items():
        np.savez(os.path.join(path, f"{split}_{suffix}.npz"), data=array)


if __name__ == "__main__":
    # Remove any stale copy of the fake data before regenerating it.
    if os.path.isdir(root_dir):
        shutil.rmtree(root_dir)

    os.makedirs(root_dir)

    # One subdirectory per country, each holding files for all three splits.
    for country in countries:
        country_dir = os.path.join(root_dir, country)  # renamed: `dir` shadowed the builtin
        os.makedirs(country_dir)

        for split in splits:
            create_files(country_dir, split)

    filename = root_dir + ".zip"

    # Compress data (base name without the .zip suffix is just root_dir).
    shutil.make_archive(root_dir, "zip", ".", root_dir)

    # Compute and print the checksum that the dataset's `md5` attribute
    # (and the test suite) must be updated to.
    with open(filename, "rb") as f:
        md5 = hashlib.md5(f.read()).hexdigest()
    print(f"{filename}: {md5}")
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
82 changes: 82 additions & 0 deletions tests/datasets/test_sustainbench_crop_yield_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import SustainBenchCropYieldPrediction


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    """Mock of torchgeo's ``download_url`` that copies a local test file into ``root``.

    Extra positional and keyword arguments are accepted and ignored so the
    signature is compatible with the real downloader.
    """
    shutil.copy(src=url, dst=root)

class TestSustainbenchCropYieldPrediction:
    """Unit tests for :class:`SustainBenchCropYieldPrediction`."""

    @pytest.fixture(params=["train", "dev", "test"])
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> SustainBenchCropYieldPrediction:
        """Return a dataset instance backed by the small fake archive.

        Parameterized over all three splits; downloads are redirected to a
        local copy of ``soybeans.zip`` generated by the data.py script.
        """
        # Replace the real network downloader with the local-copy mock above.
        monkeypatch.setattr(
            torchgeo.datasets.sustainbench_crop_yield_prediction,
            "download_url",
            download_url,
        )

        # Checksum of the fake test archive, not of the real dataset.
        md5 = "7a5591794e14dd73d2b747cd2244acbc"
        monkeypatch.setattr(SustainBenchCropYieldPrediction, "md5", md5)
        url = os.path.join(
            "tests", "data", "sustainbench_crop_yield_prediction", "soybeans.zip"
        )
        monkeypatch.setattr(SustainBenchCropYieldPrediction, "url", url)
        # Prevent plot tests from opening GUI windows.
        monkeypatch.setattr(plt, "show", lambda *args: None)
        root = str(tmp_path)
        split = request.param
        countries = ["argentina", "brazil", "usa"]
        transforms = nn.Identity()
        return SustainBenchCropYieldPrediction(
            root, split, countries, transforms, download=True, checksum=True
        )

    @pytest.mark.parametrize("index", [0, 1, 2])
    def test_getitem(
        self, dataset: SustainBenchCropYieldPrediction, index: int
    ) -> None:
        """Samples expose image/label/year/ndvi tensors with the fake data shape."""
        x = dataset[index]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["label"], torch.Tensor)
        assert isinstance(x["year"], torch.Tensor)
        assert isinstance(x["ndvi"], torch.Tensor)
        # Fake data is (NUM_BANDS, SIZE, SIZE) = (9, 32, 32); see data.py.
        assert x["image"].shape == (9, 32, 32)

    def test_len(self, dataset: SustainBenchCropYieldPrediction) -> None:
        """Each country contributes NUM_SAMPLES (3) samples per split."""
        assert len(dataset) == len(dataset.countries) * 3

    def test_already_downloaded(self, dataset: SustainBenchCropYieldPrediction) -> None:
        """Re-instantiating with download=True on an existing root must not re-download."""
        SustainBenchCropYieldPrediction(root=dataset.root, download=True)

    def test_invalid_split(self) -> None:
        """An unknown split name is rejected with an AssertionError."""
        with pytest.raises(AssertionError):
            SustainBenchCropYieldPrediction(split="foo")

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty root without download=True raises RuntimeError."""
        with pytest.raises(RuntimeError, match="Dataset not found in"):
            SustainBenchCropYieldPrediction(str(tmp_path))

    def test_plot(self, dataset: SustainBenchCropYieldPrediction) -> None:
        """plot() works both with and without a 'prediction' key in the sample."""
        dataset.plot(dataset[0], suptitle="Test")
        plt.close()

        sample = dataset[0]
        sample["prediction"] = sample["label"]
        dataset.plot(sample)
        plt.close()
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .so2sat import So2SatDataModule
from .spacenet import SpaceNet1DataModule
from .ssl4eo import SSL4EOS12DataModule
from .sustainbench_crop_yield_prediction import SustainbenchCropYieldDataModule
from .ucmerced import UCMercedDataModule
from .usavars import USAVarsDataModule
from .utils import MisconfigurationException
Expand Down Expand Up @@ -61,6 +62,7 @@
"So2SatDataModule",
"SpaceNet1DataModule",
"SSL4EOS12DataModule",
"SustainbenchCropYieldDataModule",
"TropicalCycloneDataModule",
"UCMercedDataModule",
"USAVarsDataModule",
Expand Down
53 changes: 53 additions & 0 deletions torchgeo/datamodules/sustainbench_crop_yield_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Sustainbench Crop Yield Prediction datamodule."""

from typing import Any

from ..datasets import SustainBenchCropYieldPrediction
from .geo import NonGeoDataModule


class SustainbenchCropYieldDataModule(NonGeoDataModule):
    """LightningDataModule for Sustainbench Crop Yield Prediction dataset."""

    def __init__(
        self, batch_size: int = 32, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a new SustainbenchCropYieldDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.SustainBenchCropYieldPrediction`.

        .. versionadded:: 0.5
        """
        super().__init__(
            SustainBenchCropYieldPrediction, batch_size, num_workers, **kwargs
        )

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Called at the beginning of fit, validate, test, or predict. During
        distributed training this runs on every process, so only state that is
        safe to set everywhere belongs here.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        if stage == "fit":
            self.train_dataset = self.dataset_class(  # type: ignore[call-arg]
                split="train", **self.kwargs
            )
        if stage in ("fit", "validate"):
            # The dataset names its validation split "dev".
            self.val_dataset = self.dataset_class(  # type: ignore[call-arg]
                split="dev", **self.kwargs
            )
        if stage == "test":
            self.test_dataset = self.dataset_class(  # type: ignore[call-arg]
                split="test", **self.kwargs
            )
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
time_series_split,
)
from .ssl4eo import SSL4EOS12
from .sustainbench_crop_yield_prediction import SustainBenchCropYieldPrediction
from .ucmerced import UCMerced
from .usavars import USAVars
from .utils import (
Expand Down Expand Up @@ -202,6 +203,7 @@
"SpaceNet6",
"SpaceNet7",
"SSL4EOS12",
"SustainBenchCropYieldPrediction",
"TropicalCyclone",
"UCMerced",
"USAVars",
Expand Down
Loading