Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSL4EO-S12: add new dataset/datamodule #1151

Merged
merged 31 commits into from Apr 15, 2023
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
e7afd1f
SSL4EO-S12: add new dataset
adamjstewart Feb 28, 2023
44f2953
Style fixes
adamjstewart Feb 28, 2023
45fc445
100% coverage
adamjstewart Feb 28, 2023
729d377
fix mypy
adamjstewart Feb 28, 2023
a2f39d5
black fixes
adamjstewart Feb 28, 2023
0f0e02d
mypy fix
adamjstewart Mar 1, 2023
1ad0fb6
Convert from db to power
adamjstewart Mar 6, 2023
d2a8049
Don't cast to numpy
adamjstewart Mar 7, 2023
3df7ec1
Remove comments referring to SeCo
adamjstewart Mar 7, 2023
a424a77
SSL4EO: add extraction time
adamjstewart Mar 9, 2023
ec0379a
Add RandomSeasonContrast
adamjstewart Mar 10, 2023
6172fff
Fix axes indexing
adamjstewart Mar 10, 2023
4b0774c
Add datamodule
adamjstewart Mar 17, 2023
7b32aa3
fix tests
adamjstewart Mar 18, 2023
bd5962a
mypy fixes
adamjstewart Mar 18, 2023
0265ccf
fix missing import
adamjstewart Mar 18, 2023
6529150
Fix tests
adamjstewart Mar 18, 2023
8268f6e
isort fix
adamjstewart Mar 18, 2023
05dab7d
Typo fix
adamjstewart Mar 19, 2023
1fd7a2b
s2c: add B10
adamjstewart Mar 27, 2023
80a5c3a
Update test channels
adamjstewart Mar 27, 2023
fe28fe9
S2 plotting was broken
calebrob6 Mar 29, 2023
651347b
Fix plotting
calebrob6 Mar 29, 2023
5cae88d
Merge branch 'main' into datasets/ssl4eo
adamjstewart Mar 29, 2023
b4e4410
Black fix
adamjstewart Mar 29, 2023
c4a5c75
Merge branch 'main' into datasets/ssl4eo
adamjstewart Mar 30, 2023
5dbcab6
Rename conf files
adamjstewart Mar 30, 2023
8af0a53
Remove file introduced by bad merge
adamjstewart Mar 31, 2023
87f068e
Fix pixel size of bands
adamjstewart Apr 9, 2023
d553b57
black fix
adamjstewart Apr 9, 2023
90714a0
Better S1 plotting
adamjstewart Apr 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Expand Up @@ -109,6 +109,11 @@ SpaceNet

.. autoclass:: SpaceNet1DataModule

SSL4EO
^^^^^^

.. autoclass:: SSL4EOS12DataModule

Tropical Cyclone
^^^^^^^^^^^^^^^^

Expand Down
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Expand Up @@ -298,6 +298,11 @@ SpaceNet
.. autoclass:: SpaceNet6
.. autoclass:: SpaceNet7

SSL4EO
^^^^^^

.. autoclass:: SSL4EOS12

Tropical Cyclone
^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/non_geo_datasets.csv
Expand Up @@ -28,6 +28,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`SEN12MS`_,S,"Sentinel-1/2, MODIS","180,662",33,256x256,10,"SAR, MSI"
`So2Sat`_,C,Sentinel-1/2,"400,673",17,32x32,10,"SAR, MSI"
`SpaceNet`_,I,WorldView-2/3 Planet Lab Dove,"1,889--28,728",2,102--900,0.5--4,MSI
`SSL4EO`_,T,Sentinel-1/2,1M,-,264x264,10,"SAR, MSI"
`Tropical Cyclone`_,R,GOES 8--16,"108,110",-,256x256,4K--8K,MSI
`UC Merced`_,C,USGS National Map,"21,000",21,256x256,0.3,RGB
`USAVars`_,R,NAIP Aerial,100K,-,-,4,"RGB, NIR"
Expand Down
13 changes: 13 additions & 0 deletions tests/conf/ssl4eo_s12_1.yaml
@@ -0,0 +1,13 @@
experiment:
task: "ssl4eo_s12"
module:
in_channels: 13
backbone: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: null
datamodule:
root: "tests/data/ssl4eo/s12"
seasons: 1
batch_size: 2
num_workers: 0
13 changes: 13 additions & 0 deletions tests/conf/ssl4eo_s12_2.yaml
@@ -0,0 +1,13 @@
experiment:
task: "ssl4eo_s12"
module:
in_channels: 13
backbone: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: null
datamodule:
root: "tests/data/ssl4eo/s12"
seasons: 2
batch_size: 2
num_workers: 0
151 changes: 151 additions & 0 deletions tests/data/ssl4eo/s12/data.py
@@ -0,0 +1,151 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil
from typing import Dict, List, Union

import numpy as np
import rasterio
from rasterio import Affine
from rasterio.crs import CRS

SIZE = 36

np.random.seed(0)

FILENAME_HIERARCHY = Union[Dict[str, "FILENAME_HIERARCHY"], List[str]]

s1 = ["VH.tif", "VV.tif"]
s2c = [
"B1.tif",
"B2.tif",
"B3.tif",
"B4.tif",
"B5.tif",
"B6.tif",
"B7.tif",
"B8.tif",
"B8A.tif",
"B9.tif",
"B10.tif",
"B11.tif",
"B12.tif",
]
s2a = s2c.copy()
s2a.remove("B10.tif")
filenames: FILENAME_HIERARCHY = {
"s1": {
"0000000": {
"S1A_IW_GRDH_1SDV_20200329T001515_20200329T001540_031883_03AE27_9BAF": s1,
"S1A_IW_GRDH_1SDV_20201230T001523_20201230T001548_035908_04349D_C91E": s1,
"S1B_IW_GRDH_1SDV_20200627T001449_20200627T001514_022212_02A27E_2A09": s1,
"S1B_IW_GRDH_1SDV_20200928T120105_20200928T120130_023575_02CCB0_F035": s1,
},
"0000001": {
"S1B_IW_GRDH_1SDV_20201101T091054_20201101T091119_024069_02DC0F_F189": s1,
"S1B_IW_GRDH_1SDV_20210205T091050_20210205T091115_025469_0308CB_AA25": s1,
"S1B_IW_GRDH_1SDV_20210430T091051_20210430T091116_026694_03303D_69B6": s1,
"S1B_IW_GRDH_1SDV_20210804T091057_20210804T091122_028094_0359FE_6D9D": s1,
},
},
"s2c": {
"0000000": {
"20200323T162931_20200323T163750_T15QXA": s2c,
"20200621T162901_20200621T164746_T15QXA": s2c,
"20200924T162929_20200924T164434_T15QXA": s2c,
"20201228T163711_20201228T164519_T15QXA": s2c,
},
"0000001": {
"20201104T135121_20201104T135117_T21KXT": s2c,
"20210123T135111_20210123T135113_T21KXT": s2c,
"20210508T135109_20210508T135519_T21KXT": s2c,
"20210811T135121_20210811T135115_T21KXT": s2c,
},
},
"s2a": {
"0000000": {
"20200323T162931_20200323T163750_T15QXA": s2a,
"20200621T162901_20200621T164746_T15QXA": s2a,
"20200924T162929_20200924T164434_T15QXA": s2a,
"20201228T163711_20201228T164519_T15QXA": s2a,
},
"0000001": {
"20201104T135121_20201104T135117_T21KXT": s2a,
"20210123T135111_20210123T135113_T21KXT": s2a,
"20210508T135109_20210508T135519_T21KXT": s2a,
"20210811T135121_20210811T135115_T21KXT": s2a,
},
},
}


def create_file(path: str) -> None:
profile = {
"driver": "GTiff",
"dtype": "uint16",
"width": SIZE,
"height": SIZE,
"count": 1,
"crs": CRS.from_epsg(4326),
"transform": Affine(
9.360247437056711e-05,
0.0,
-91.84615634290395,
0.0,
-8.929489328769368e-05,
18.588542158464236,
),
}

if path.endswith("VH.tif") or path.endswith("VV.tif"):
profile["dtype"] = "float32"

if path.endswith("B1.tif") or path.endswith("B9.tif") or path.endswith("B10.tif"):
profile["width"] = profile["height"] = SIZE // 6
elif (
path.endswith("B5.tif")
or path.endswith("B6.tif")
or path.endswith("B7.tif")
or path.endswith("B8A.tif")
or path.endswith("B11.tif")
or path.endswith("B12.tif")
):
profile["width"] = profile["height"] = SIZE // 2

Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"])

with rasterio.open(path, "w", **profile) as src:
src.write(Z, 1)


def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
if isinstance(hierarchy, dict):
# Recursive case
for key, value in hierarchy.items():
path = os.path.join(directory, key)
os.makedirs(path, exist_ok=True)
create_directory(path, value)
else:
# Base case
for value in hierarchy:
path = os.path.join(directory, value)
create_file(path)


if __name__ == "__main__":
create_directory(".", filenames)

files = ["s1", "s2_l1c", "s2_l2a"]
directories = ["s1", "s2c", "s2a"]
for file, directory in zip(files, directories):
# Create tarballs
shutil.make_archive(file, "gztar", ".", directory)

# Compute checksums
with open(f"{file}.tar.gz", "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(file, md5)
Binary file added tests/data/ssl4eo/s12/s1.tar.gz
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/ssl4eo/s12/s2_l1c.tar.gz
Binary file not shown.
Binary file added tests/data/ssl4eo/s12/s2_l2a.tar.gz
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
16 changes: 8 additions & 8 deletions tests/datamodules/test_geo.py
Expand Up @@ -97,28 +97,28 @@ def test_setup(self, stage: str) -> None:

def test_train(self, datamodule: CustomGeoDataModule) -> None:
datamodule.setup("fit")
datamodule.trainer.training = True # type: ignore[union-attr]
datamodule.trainer.training = True
batch = next(iter(datamodule.train_dataloader()))
batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1)
batch = datamodule.on_after_batch_transfer(batch, 0)

def test_val(self, datamodule: CustomGeoDataModule) -> None:
datamodule.setup("validate")
datamodule.trainer.validating = True # type: ignore[union-attr]
datamodule.trainer.validating = True
batch = next(iter(datamodule.val_dataloader()))
batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1)
batch = datamodule.on_after_batch_transfer(batch, 0)

def test_test(self, datamodule: CustomGeoDataModule) -> None:
datamodule.setup("test")
datamodule.trainer.testing = True # type: ignore[union-attr]
datamodule.trainer.testing = True
batch = next(iter(datamodule.test_dataloader()))
batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1)
batch = datamodule.on_after_batch_transfer(batch, 0)

def test_predict(self, datamodule: CustomGeoDataModule) -> None:
datamodule.setup("predict")
datamodule.trainer.predicting = True # type: ignore[union-attr]
datamodule.trainer.predicting = True
batch = next(iter(datamodule.predict_dataloader()))
batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1)
batch = datamodule.on_after_batch_transfer(batch, 0)
Expand Down Expand Up @@ -156,25 +156,25 @@ def test_setup(self, stage: str) -> None:

def test_train(self, datamodule: CustomNonGeoDataModule) -> None:
datamodule.setup("fit")
datamodule.trainer.training = True # type: ignore[union-attr]
datamodule.trainer.training = True
batch = next(iter(datamodule.train_dataloader()))
batch = datamodule.on_after_batch_transfer(batch, 0)

def test_val(self, datamodule: CustomNonGeoDataModule) -> None:
datamodule.setup("validate")
datamodule.trainer.validating = True # type: ignore[union-attr]
datamodule.trainer.validating = True
batch = next(iter(datamodule.val_dataloader()))
batch = datamodule.on_after_batch_transfer(batch, 0)

def test_test(self, datamodule: CustomNonGeoDataModule) -> None:
datamodule.setup("test")
datamodule.trainer.testing = True # type: ignore[union-attr]
datamodule.trainer.testing = True
batch = next(iter(datamodule.test_dataloader()))
batch = datamodule.on_after_batch_transfer(batch, 0)

def test_predict(self, datamodule: CustomNonGeoDataModule) -> None:
datamodule.setup("predict")
datamodule.trainer.predicting = True # type: ignore[union-attr]
datamodule.trainer.predicting = True
batch = next(iter(datamodule.predict_dataloader()))
batch = datamodule.on_after_batch_transfer(batch, 0)

Expand Down
74 changes: 74 additions & 0 deletions tests/datasets/test_ssl4eo.py
@@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path
from typing import cast

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset

from torchgeo.datasets import SSL4EOS12


class TestSSL4EOS12:
@pytest.fixture(params=zip(SSL4EOS12.metadata.keys(), [1, 1, 2]))
def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> SSL4EOS12:
monkeypatch.setitem(
SSL4EOS12.metadata["s1"], "md5", "1661fd407a49a0fbe8c6a5073734a731"
)
monkeypatch.setitem(
SSL4EOS12.metadata["s2c"], "md5", "4946a093ea88db0f75955be318901b82"
)
monkeypatch.setitem(
SSL4EOS12.metadata["s2a"], "md5", "36944718ff658b65ca4d8724918500ac"
)

root = os.path.join("tests", "data", "ssl4eo", "s12")
split, seasons = request.param
transforms = nn.Identity()
return SSL4EOS12(root, split, seasons, transforms, checksum=True)

def test_getitem(self, dataset: SSL4EOS12) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert x["image"].size(0) == dataset.seasons * len(dataset.bands)

def test_len(self, dataset: SSL4EOS12) -> None:
assert len(dataset) == 251079

def test_add(self, dataset: SSL4EOS12) -> None:
ds = dataset + dataset
assert isinstance(ds, ConcatDataset)
assert len(ds) == 2 * 251079

def test_extract(self, tmp_path: Path) -> None:
for split in SSL4EOS12.metadata:
filename = cast(str, SSL4EOS12.metadata[split]["filename"])
shutil.copyfile(
os.path.join("tests", "data", "ssl4eo", "s12", filename),
tmp_path / filename,
)
SSL4EOS12(str(tmp_path))

def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
SSL4EOS12(split="foo")

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
SSL4EOS12(str(tmp_path))

def test_plot(self, dataset: SSL4EOS12) -> None:
sample = dataset[0]
dataset.plot(sample, suptitle="Test")
plt.close()
dataset.plot(sample, show_titles=False)
plt.close()
13 changes: 11 additions & 2 deletions tests/trainers/test_byol.py
Expand Up @@ -17,8 +17,12 @@
from torchvision.models import resnet18
from torchvision.models._api import WeightsEnum

from torchgeo.datamodules import ChesapeakeCVPRDataModule, SeasonalContrastS2DataModule
from torchgeo.datasets import SeasonalContrastS2
from torchgeo.datamodules import (
ChesapeakeCVPRDataModule,
SeasonalContrastS2DataModule,
SSL4EOS12DataModule,
)
from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
from torchgeo.models import get_model_weights, list_models
from torchgeo.trainers import BYOLTask
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation
Expand Down Expand Up @@ -55,6 +59,8 @@ class TestBYOLTask:
("chesapeake_cvpr_prior", ChesapeakeCVPRDataModule),
("seco_1", SeasonalContrastS2DataModule),
("seco_2", SeasonalContrastS2DataModule),
("ssl4eo_s12_1", SSL4EOS12DataModule),
("ssl4eo_s12_2", SSL4EOS12DataModule),
],
)
def test_trainer(
Expand All @@ -71,6 +77,9 @@ def test_trainer(
if name.startswith("seco"):
monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2)

if name.startswith("ssl4eo_s12"):
monkeypatch.setattr(SSL4EOS12, "__len__", lambda self: 2)

# Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"]
datamodule = classname(**datamodule_kwargs)
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Expand Up @@ -25,6 +25,7 @@
from .sen12ms import SEN12MSDataModule
from .so2sat import So2SatDataModule
from .spacenet import SpaceNet1DataModule
from .ssl4eo import SSL4EOS12DataModule
from .ucmerced import UCMercedDataModule
from .usavars import USAVarsDataModule
from .utils import MisconfigurationException
Expand Down Expand Up @@ -55,6 +56,7 @@
"SEN12MSDataModule",
"So2SatDataModule",
"SpaceNet1DataModule",
"SSL4EOS12DataModule",
"TropicalCycloneDataModule",
"UCMercedDataModule",
"USAVarsDataModule",
Expand Down