Skip to content

Commit

Permalink
Add datamodule
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Mar 17, 2023
1 parent 6172fff commit 4b0774c
Show file tree
Hide file tree
Showing 115 changed files with 79 additions and 3 deletions.
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Expand Up @@ -109,6 +109,11 @@ SpaceNet

.. autoclass:: SpaceNet1DataModule

SSL4EO
^^^^^^

.. autoclass:: SSL4EOS12DataModule

Tropical Cyclone
^^^^^^^^^^^^^^^^

Expand Down
13 changes: 13 additions & 0 deletions tests/conf/ssl4eo_s12_1.yaml
@@ -0,0 +1,13 @@
experiment:
task: "ssl4eo_s12"
module:
in_channels: 3
backbone: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: null
datamodule:
root: "tests/data/ssl4eo_s12"
seasons: 1
batch_size: 2
num_workers: 0
13 changes: 13 additions & 0 deletions tests/conf/ssl4eo_s12_2.yaml
@@ -0,0 +1,13 @@
experiment:
task: "ssl4eo_s12"
module:
in_channels: 3
backbone: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: null
datamodule:
root: "tests/data/ssl4eo_s12"
seasons: 2
batch_size: 2
num_workers: 0
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions tests/datasets/test_ssl4eo.py
Expand Up @@ -30,7 +30,7 @@ def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> SSL4EOS12:
SSL4EOS12.metadata["s2a"], "md5", "bc0bc2e5e0ad93a510330b90cd157c95"
)

root = os.path.join("tests", "data", "ssl4eo")
root = os.path.join("tests", "data", "ssl4eo_s12")
split, seasons = request.param
transforms = nn.Identity()
return SSL4EOS12(root, split, seasons, transforms, checksum=True)
Expand All @@ -53,7 +53,7 @@ def test_extract(self, tmp_path: Path) -> None:
for split in SSL4EOS12.metadata:
filename = cast(str, SSL4EOS12.metadata[split]["filename"])
shutil.copyfile(
os.path.join("tests", "data", "ssl4eo", filename), tmp_path / filename
os.path.join("tests", "data", "ssl4eo_s12", filename), tmp_path / filename
)
SSL4EOS12(str(tmp_path))

Expand Down
4 changes: 3 additions & 1 deletion tests/trainers/test_byol.py
Expand Up @@ -17,7 +17,7 @@
from torchvision.models import resnet18
from torchvision.models._api import WeightsEnum

from torchgeo.datamodules import ChesapeakeCVPRDataModule, SeasonalContrastS2DataModule
from torchgeo.datamodules import ChesapeakeCVPRDataModule, SeasonalContrastS2DataModule, SSL4EOS12DataModule
from torchgeo.datasets import SeasonalContrastS2
from torchgeo.models import get_model_weights, list_models
from torchgeo.trainers import BYOLTask
Expand Down Expand Up @@ -55,6 +55,8 @@ class TestBYOLTask:
("chesapeake_cvpr_prior", ChesapeakeCVPRDataModule),
("seco_1", SeasonalContrastS2DataModule),
("seco_2", SeasonalContrastS2DataModule),
("ssl4eo_s12_1", SSL4EOS12DataModule),
("ssl4eo_s12_2", SSL4EOS12DataModule),
],
)
def test_trainer(
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Expand Up @@ -25,6 +25,7 @@
from .sen12ms import SEN12MSDataModule
from .so2sat import So2SatDataModule
from .spacenet import SpaceNet1DataModule
from .ssl4eo import SSL4EOS12DataModule
from .ucmerced import UCMercedDataModule
from .usavars import USAVarsDataModule
from .utils import MisconfigurationException
Expand Down Expand Up @@ -55,6 +56,7 @@
"SEN12MSDataModule",
"So2SatDataModule",
"SpaceNet1DataModule",
"SSL4EOS12DataModule",
"TropicalCycloneDataModule",
"UCMercedDataModule",
"USAVarsDataModule",
Expand Down
41 changes: 41 additions & 0 deletions torchgeo/datamodules/ssl4eo.py
@@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""SSL4EO datamodule."""

from typing import Any

from ..datasets import SSL4EOS12
from .geo import NonGeoDataModule


class SSL4EOS12DataModule(NonGeoDataModule):
"""LightningDataModule implementation for the SSL4EO-S12 dataset.
.. versionadded:: 0.5
"""

# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
mean = 0.0
std = 10000.0

def __init__(
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a new SSL4EOS12DataModule instance.
Args:
batch_size: Size of each mini-batch.
num_workers: Number of workers for parallel data loading.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.SSL4EOS12`.
"""
super().__init__(SSL4EOS12, batch_size, num_workers, **kwargs)

def setup(self, stage: str) -> None:
"""Set up datasets.
Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
self.dataset = SSL4EOS12(**self.kwargs)

0 comments on commit 4b0774c

Please sign in to comment.