From d3c82a5ac8f9e6c471ca2ad6e88a3f6308d35d44 Mon Sep 17 00:00:00 2001 From: Isaac Corley <22203655+isaaccorley@users.noreply.github.com> Date: Sun, 23 Apr 2023 17:37:29 -0500 Subject: [PATCH] Refactor train.py (#1237) * add pretrain.py tested with seco100k * refactor * add pretrain.py tested with seco100k * refactor * refactor to train.py and add simclr * revert simclr and pretrain.py changes * revert more simclr changes * add refactor to configs and train.py * add hydra.utils.instantiate import * fix flake8 * update tests and add hydra-core to deps * fix byol tests * update exp name * format * remove evaluate.py * add hydra-core to min tests deps * update tests * add trainer to configs and use lightning.Trainer instead of pl.Trainer * add eurosat100 test * update train.py * lightning.Trainer -> lightning.pytorch.Trainer * remove omegaconf * update hydra-core to 1.1.1 and fix mypy * add recursive hydra config test * update comment * update test config * fix tests * add omegaconf back in * remove hydra recursive test * update hydra-core to 2.3.0 for ci * Try older hydra * Older hydra requires old omegaconf * Try older hydra * Try newer hydra * Try older hydra * Try newer hydra * Finalize minimum dep versions --------- Co-authored-by: Adam J. 
Stewart --- .pre-commit-config.yaml | 2 +- conf/bigearthnet.yaml | 40 +-- conf/byol.yaml | 26 -- conf/chesapeake_cvpr.yaml | 61 ++--- conf/cowc_counting.yaml | 31 ++- conf/cyclone.yaml | 31 ++- conf/deepglobelandcover.yaml | 48 ++-- conf/defaults.yaml | 2 +- conf/etci2021.yaml | 40 +-- conf/eurosat.yaml | 36 ++- conf/gid15.yaml | 48 ++-- conf/inria.yaml | 46 ++-- conf/landcoverai.yaml | 42 +-- conf/naipchesapeake.yaml | 47 ++-- conf/nasa_marine_debris.yaml | 40 ++- conf/potsdam2d.yaml | 48 ++-- conf/resisc45.yaml | 36 +-- conf/seco_100k.yaml | 24 ++ conf/sen12ms.yaml | 43 +-- conf/so2sat.yaml | 38 +-- conf/spacenet1.yaml | 45 ++-- conf/ucmerced.yaml | 36 ++- conf/vaihingen2d.yaml | 48 ++-- environment.yml | 3 +- evaluate.py | 294 --------------------- requirements/min-reqs.old | 3 +- requirements/tests.txt | 1 + setup.cfg | 6 +- tests/conf/bigearthnet_all.yaml | 35 +-- tests/conf/bigearthnet_s1.yaml | 35 +-- tests/conf/bigearthnet_s2.yaml | 35 +-- tests/conf/chesapeake_cvpr_5.yaml | 55 ++-- tests/conf/chesapeake_cvpr_7.yaml | 55 ++-- tests/conf/chesapeake_cvpr_prior_byol.yaml | 50 ++-- tests/conf/cowc_counting.yaml | 29 +- tests/conf/cyclone.yaml | 29 +- tests/conf/deepglobelandcover.yaml | 41 +-- tests/conf/etci2021.yaml | 35 +-- tests/conf/eurosat.yaml | 31 +-- tests/conf/eurosat100.yaml | 16 ++ tests/conf/fire_risk.yaml | 31 +-- tests/conf/gid15.yaml | 43 +-- tests/conf/inria.yaml | 39 +-- tests/conf/l7irish.yaml | 43 +-- tests/conf/l8biome.yaml | 43 +-- tests/conf/landcoverai.yaml | 39 +-- tests/conf/loveda.yaml | 39 +-- tests/conf/naipchesapeake.yaml | 41 +-- tests/conf/nasa_marine_debris.yaml | 29 +- tests/conf/potsdam2d.yaml | 41 +-- tests/conf/resisc45.yaml | 31 +-- tests/conf/seco_byol_1.yaml | 27 +- tests/conf/seco_byol_2.yaml | 27 +- tests/conf/sen12ms_all.yaml | 35 +-- tests/conf/sen12ms_s1.yaml | 37 +-- tests/conf/sen12ms_s2_all.yaml | 35 +-- tests/conf/sen12ms_s2_reduced.yaml | 35 +-- tests/conf/skippd.yaml | 29 +- tests/conf/so2sat_all.yaml | 31 +-- 
tests/conf/so2sat_s1.yaml | 31 +-- tests/conf/so2sat_s2.yaml | 31 +-- tests/conf/spacenet1.yaml | 43 +-- tests/conf/ssl4eo_s12_byol_1.yaml | 27 +- tests/conf/ssl4eo_s12_byol_2.yaml | 27 +- tests/conf/sustainbench_crop_yield.yaml | 29 +- tests/conf/ucmerced.yaml | 31 +-- tests/conf/vaihingen2d.yaml | 41 +-- tests/trainers/test_byol.py | 39 +-- tests/trainers/test_classification.py | 67 ++--- tests/trainers/test_detection.py | 26 +- tests/trainers/test_regression.py | 36 +-- tests/trainers/test_segmentation.py | 77 ++---- train.py | 144 +++------- 73 files changed, 1292 insertions(+), 1643 deletions(-) delete mode 100644 conf/byol.yaml create mode 100644 conf/seco_100k.yaml delete mode 100755 evaluate.py create mode 100644 tests/conf/eurosat100.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c9bbd99ab4a..299079e1f00 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,5 +34,5 @@ repos: hooks: - id: mypy args: [--strict, --ignore-missing-imports, --show-error-codes] - additional_dependencies: [torch>=2, torchmetrics>=0.10, lightning>=1.8, pytest>=6, pyvista>=0.20, omegaconf>=2.1, kornia>=0.6, numpy>=1.22.0] + additional_dependencies: [torch>=2, torchmetrics>=0.10, lightning>=1.8, pytest>=6, pyvista>=0.20, omegaconf>=2.0.1, hydra-core>=1, kornia>=0.6, numpy>=1.22] exclude: (build|data|dist|logo|logs|output)/ diff --git a/conf/bigearthnet.yaml b/conf/bigearthnet.yaml index 131728d61d9..2c8a4da5218 100644 --- a/conf/bigearthnet.yaml +++ b/conf/bigearthnet.yaml @@ -1,22 +1,24 @@ +module: + _target_: torchgeo.trainers.MultiLabelClassificationTask + loss: "bce" + model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + in_channels: 14 + num_classes: 19 + +datamodule: + _target_: torchgeo.datamodules.BigEarthNetDataModule + root: "data/bigearthnet" + bands: "all" + num_classes: ${module.num_classes} + batch_size: 128 + num_workers: 4 + trainer: + _target_: lightning.pytorch.Trainer 
accelerator: gpu devices: 1 - min_epochs: 10 - max_epochs: 40 - benchmark: True -experiment: - task: "bigearthnet" - module: - loss: "bce" - model: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - in_channels: 14 - num_classes: 19 - datamodule: - root: "data/bigearthnet" - bands: "all" - num_classes: ${experiment.module.num_classes} - batch_size: 128 - num_workers: 4 + min_epochs: 15 + max_epochs: 40 \ No newline at end of file diff --git a/conf/byol.yaml b/conf/byol.yaml deleted file mode 100644 index b1a7ac927a8..00000000000 --- a/conf/byol.yaml +++ /dev/null @@ -1,26 +0,0 @@ -trainer: - accelerator: gpu - devices: 1 - min_epochs: 20 - max_epochs: 100 - benchmark: True -experiment: - task: "ssl" - name: "test_byol" - module: - model: "byol" - backbone: "resnet18" - input_channels: 4 - weights: imagenet - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - datamodule: - root: "data/chesapeake/cvpr" - train_splits: - - "de-train" - val_splits: - - "de-val" - test_splits: - - "de-test" - batch_size: 64 - num_workers: 4 diff --git a/conf/chesapeake_cvpr.yaml b/conf/chesapeake_cvpr.yaml index bacf56358db..d30f555187e 100644 --- a/conf/chesapeake_cvpr.yaml +++ b/conf/chesapeake_cvpr.yaml @@ -1,33 +1,34 @@ +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 4 + num_classes: 7 + num_filters: 256 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.ChesapeakeCVPRDataModule + root: "data/chesapeake/cvpr" + train_splits: + - "de-train" + val_splits: + - "de-val" + test_splits: + - "de-test" + batch_size: 200 + patch_size: 256 + num_workers: 4 + class_set: ${module.num_classes} + use_prior_labels: False + trainer: + _target_: lightning.pytorch.Trainer accelerator: gpu devices: 1 - min_epochs: 20 - max_epochs: 100 - benchmark: True -experiment: - task: 
"chesapeake_cvpr" - name: "chesapeake_cvpr_example" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 4 - num_classes: 7 - num_filters: 256 - ignore_index: null - datamodule: - root: "data/chesapeake/cvpr" - train_splits: - - "de-train" - val_splits: - - "de-val" - test_splits: - - "de-test" - batch_size: 200 - patch_size: 256 - num_workers: 4 - class_set: ${experiment.module.num_classes} - use_prior_labels: False + min_epochs: 15 + max_epochs: 40 \ No newline at end of file diff --git a/conf/cowc_counting.yaml b/conf/cowc_counting.yaml index 334672d159a..787bb55b835 100644 --- a/conf/cowc_counting.yaml +++ b/conf/cowc_counting.yaml @@ -1,18 +1,21 @@ +module: + _target_: torchgeo.trainers.RegressionTask + model: resnet18 + weights: null + num_outputs: 1 + in_channels: 3 + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + +datamodule: + _target_: torchgeo.datamodules.COWCCountingDataModule + root: "data/cowc_counting" + batch_size: 64 + num_workers: 4 + trainer: + _target_: lightning.pytorch.Trainer accelerator: gpu devices: 1 min_epochs: 15 -experiment: - task: cowc_counting - name: cowc_counting_test - module: - model: resnet18 - weights: null - num_outputs: 1 - in_channels: 3 - learning_rate: 1e-3 - learning_rate_schedule_patience: 2 - datamodule: - root: "data/cowc_counting" - batch_size: 64 - num_workers: 4 + max_epochs: 40 \ No newline at end of file diff --git a/conf/cyclone.yaml b/conf/cyclone.yaml index 733a48885e7..f5200ae8d74 100644 --- a/conf/cyclone.yaml +++ b/conf/cyclone.yaml @@ -1,18 +1,21 @@ +module: + _target_: torchgeo.trainers.RegressionTask + model: "resnet18" + weights: null + num_outputs: 1 + in_channels: 3 + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + +datamodule: + _target_: torchgeo.datamodules.TropicalCycloneDataModule + root: "data/cyclone" + batch_size: 32 + num_workers: 4 + trainer: + _target_: 
lightning.pytorch.Trainer accelerator: gpu devices: 1 min_epochs: 15 -experiment: - task: "cyclone" - name: "cyclone_test" - module: - model: "resnet18" - weights: null - num_outputs: 1 - in_channels: 3 - learning_rate: 1e-3 - learning_rate_schedule_patience: 2 - datamodule: - root: "data/cyclone" - batch_size: 32 - num_workers: 4 + max_epochs: 40 \ No newline at end of file diff --git a/conf/deepglobelandcover.yaml b/conf/deepglobelandcover.yaml index 2e09eca0e4b..9c7da9adbb4 100644 --- a/conf/deepglobelandcover.yaml +++ b/conf/deepglobelandcover.yaml @@ -1,20 +1,28 @@ -experiment: - task: "deepglobelandcover" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 3 - num_classes: 7 - num_filters: 1 - ignore_index: null - datamodule: - root: "data/deepglobelandcover" - batch_size: 1 - patch_size: 64 - val_split_pct: 0.5 - num_workers: 0 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 3 + num_classes: 7 + num_filters: 1 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.DeepGlobeLandCoverDataModule + root: "data/deepglobelandcover" + batch_size: 1 + patch_size: 64 + val_split_pct: 0.5 + num_workers: 0 + +trainer: + _target_: lightning.pytorch.Trainer + accelerator: gpu + devices: 1 + min_epochs: 15 + max_epochs: 40 \ No newline at end of file diff --git a/conf/defaults.yaml b/conf/defaults.yaml index 15d58be2656..adcdf816e43 100644 --- a/conf/defaults.yaml +++ b/conf/defaults.yaml @@ -5,4 +5,4 @@ program: # These are the arguments that define how the train.py script works output_dir: output data_dir: data log_dir: logs - overwrite: False + overwrite: False \ No newline at end of file diff --git a/conf/etci2021.yaml b/conf/etci2021.yaml index 
c6e02005d10..2550f7e234d 100644 --- a/conf/etci2021.yaml +++ b/conf/etci2021.yaml @@ -1,16 +1,24 @@ -experiment: - task: "etci2021" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: "imagenet" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 6 - num_classes: 2 - ignore_index: 0 - datamodule: - root: "data/etci2021" - batch_size: 32 - num_workers: 4 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: "imagenet" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 6 + num_classes: 2 + ignore_index: 0 + +datamodule: + _target_: torchgeo.datamodules.ETCI2021DataModule + root: "data/etci2021" + batch_size: 32 + num_workers: 4 + +trainer: + _target_: lightning.pytorch.Trainer + accelerator: gpu + devices: 1 + min_epochs: 15 + max_epochs: 40 \ No newline at end of file diff --git a/conf/eurosat.yaml b/conf/eurosat.yaml index 89dddfd1941..6f744a04cda 100644 --- a/conf/eurosat.yaml +++ b/conf/eurosat.yaml @@ -1,14 +1,22 @@ -experiment: - task: "eurosat" - module: - loss: "ce" - model: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - in_channels: 13 - num_classes: 10 - datamodule: - root: "data/eurosat" - batch_size: 128 - num_workers: 4 +module: + _target_: torchgeo.trainers.ClassificationTask + loss: "ce" + model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + in_channels: 13 + num_classes: 10 + +datamodule: + _target_: torchgeo.datamodules.EuroSATDataModule + root: "data/eurosat" + batch_size: 128 + num_workers: 4 + +trainer: + _target_: lightning.pytorch.Trainer + accelerator: gpu + devices: 1 + min_epochs: 15 + max_epochs: 40 \ No newline at end of file diff --git a/conf/gid15.yaml b/conf/gid15.yaml index 420c6b2f0e9..2f21fc94195 100644 --- a/conf/gid15.yaml +++ b/conf/gid15.yaml @@ -1,20 +1,28 @@ -experiment: - task: "gid15" - module: - loss: "ce" - 
model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 3 - num_classes: 16 - num_filters: 1 - ignore_index: null - datamodule: - root: "data/gid15" - batch_size: 1 - patch_size: 64 - val_split_pct: 0.5 - num_workers: 0 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 3 + num_classes: 16 + num_filters: 1 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.GID15DataModule + root: "data/gid15" + batch_size: 1 + patch_size: 64 + val_split_pct: 0.5 + num_workers: 0 + +trainer: + _target_: lightning.pytorch.Trainer + accelerator: gpu + devices: 1 + min_epochs: 15 + max_epochs: 40 \ No newline at end of file diff --git a/conf/inria.yaml b/conf/inria.yaml index a269f0bd5e9..e0f716e292b 100644 --- a/conf/inria.yaml +++ b/conf/inria.yaml @@ -1,29 +1,25 @@ -program: - overwrite: True +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: "imagenet" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 3 + num_classes: 2 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.InriaAerialImageLabelingDataModule + root: "data/inria" + batch_size: 1 + patch_size: 512 + num_workers: 32 trainer: + _target_: lightning.pytorch.Trainer accelerator: gpu devices: 1 - min_epochs: 5 - max_epochs: 100 - benchmark: True - log_every_n_steps: 2 - -experiment: - task: "inria" - name: "inria_test" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: "imagenet" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 3 - num_classes: 2 - ignore_index: null - datamodule: - root: "data/inria" - batch_size: 1 - patch_size: 512 - num_workers: 32 + min_epochs: 15 + max_epochs: 40 \ No 
newline at end of file diff --git a/conf/landcoverai.yaml b/conf/landcoverai.yaml index ef5261abdee..0136527a19a 100644 --- a/conf/landcoverai.yaml +++ b/conf/landcoverai.yaml @@ -1,23 +1,25 @@ +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: "imagenet" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 3 + num_classes: 5 + num_filters: 256 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.LandCoverAIDataModule + root: "data/landcoverai" + batch_size: 32 + num_workers: 4 + trainer: + _target_: lightning.pytorch.Trainer accelerator: gpu devices: 1 - min_epochs: 20 - max_epochs: 100 - benchmark: True -experiment: - task: "landcoverai" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: "imagenet" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 3 - num_classes: 5 - num_filters: 256 - ignore_index: null - datamodule: - root: "data/landcoverai" - batch_size: 32 - num_workers: 4 + min_epochs: 15 + max_epochs: 40 \ No newline at end of file diff --git a/conf/naipchesapeake.yaml b/conf/naipchesapeake.yaml index 709224eca9d..ede6db4e336 100644 --- a/conf/naipchesapeake.yaml +++ b/conf/naipchesapeake.yaml @@ -1,24 +1,27 @@ -program: - experiment_name: "naip_chesapeake_test" - overwrite: True +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "deeplabv3+" + backbone: "resnet34" + weights: "imagenet" + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + in_channels: 4 + num_classes: 14 + num_filters: 64 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.NAIPChesapeakeDataModule + naip_root: "data/naip" + chesapeake_root: "data/chesapeake/BAYWIDE" + batch_size: 32 + num_workers: 4 + patch_size: 32 + trainer: + _target_: lightning.pytorch.Trainer + accelerator: gpu + devices: 1 min_epochs: 15 -experiment: - task: "naipchesapeake" - module: - loss: "ce" - 
model: "deeplabv3+" - backbone: "resnet34" - weights: "imagenet" - learning_rate: 1e-3 - learning_rate_schedule_patience: 2 - in_channels: 4 - num_classes: 14 - num_filters: 64 - ignore_index: null - datamodule: - naip_root: "data/naip" - chesapeake_root: "data/chesapeake/BAYWIDE" - batch_size: 32 - num_workers: 4 - patch_size: 32 + max_epochs: 40 \ No newline at end of file diff --git a/conf/nasa_marine_debris.yaml b/conf/nasa_marine_debris.yaml index 4908400bec1..89164a63c74 100644 --- a/conf/nasa_marine_debris.yaml +++ b/conf/nasa_marine_debris.yaml @@ -1,26 +1,22 @@ -program: - overwrite: True +module: + _target_: torchgeo.trainers.ObjectDetectionTask + model: "faster-rcnn" + backbone: "resnet50" + num_classes: 2 + learning_rate: 1.2e-4 + learning_rate_schedule_patience: 6 + verbose: false + +datamodule: + _target_: torchgeo.datamodules.NASAMarineDebrisDataModule + root: "data/nasamr/nasa_marine_debris" + batch_size: 4 + num_workers: 6 + val_split_pct: 0.2 trainer: + _target_: lightning.pytorch.Trainer accelerator: gpu devices: 1 - min_epochs: 5 - max_epochs: 100 - auto_lr_find: False - benchmark: True - -experiment: - task: "nasa_marine_debris" - name: "nasa_marine_debris_test" - module: - model: "faster-rcnn" - backbone: "resnet50" - num_classes: 2 - learning_rate: 1.2e-4 - learning_rate_schedule_patience: 6 - verbose: false - datamodule: - root: "data/nasamr/nasa_marine_debris" - batch_size: 4 - num_workers: 6 - val_split_pct: 0.2 + min_epochs: 15 + max_epochs: 40 \ No newline at end of file diff --git a/conf/potsdam2d.yaml b/conf/potsdam2d.yaml index e1312fa57d4..076a1d75f72 100644 --- a/conf/potsdam2d.yaml +++ b/conf/potsdam2d.yaml @@ -1,20 +1,28 @@ -experiment: - task: "potsdam2d" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 4 - num_classes: 6 - num_filters: 1 - ignore_index: null - datamodule: - root: "data/potsdam" - batch_size: 1 
- patch_size: 64 - val_split_pct: 0.5 - num_workers: 0 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 4 + num_classes: 6 + num_filters: 1 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.Potsdam2DDataModule + root: "data/potsdam" + batch_size: 1 + patch_size: 64 + val_split_pct: 0.5 + num_workers: 0 + +trainer: + _target_: lightning.pytorch.Trainer + accelerator: gpu + devices: 1 + min_epochs: 15 + max_epochs: 40 \ No newline at end of file diff --git a/conf/resisc45.yaml b/conf/resisc45.yaml index ad57f856d5d..05978aa5e84 100644 --- a/conf/resisc45.yaml +++ b/conf/resisc45.yaml @@ -1,20 +1,22 @@ +module: + _target_: torchgeo.trainers.ClassificationTask + loss: "ce" + model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + in_channels: 3 + num_classes: 45 + +datamodule: + _target_: torchgeo.datamodules.RESISC45DataModule + root: "data/resisc45" + batch_size: 128 + num_workers: 4 + trainer: + _target_: lightning.pytorch.Trainer accelerator: gpu devices: 1 - min_epochs: 10 - max_epochs: 40 - benchmark: True -experiment: - task: "resisc45" - module: - loss: "ce" - model: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - in_channels: 3 - num_classes: 45 - datamodule: - root: "data/resisc45" - batch_size: 128 - num_workers: 4 + min_epochs: 15 + max_epochs: 40 \ No newline at end of file diff --git a/conf/seco_100k.yaml b/conf/seco_100k.yaml new file mode 100644 index 00000000000..e9d83fa4e87 --- /dev/null +++ b/conf/seco_100k.yaml @@ -0,0 +1,24 @@ +module: + _target_: torchgeo.trainers.BYOLTask + in_channels: 12 + backbone: "resnet18" + weights: True + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + optimizer: "Adam" + +datamodule: + _target_: 
torchgeo.datamodules.SeasonalContrastS2DataModule + root: "data/seco" + version: "100k" + seasons: 2 + bands: ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12"] + batch_size: 64 + num_workers: 16 + +trainer: + _target_: lightning.pytorch.Trainer + accelerator: gpu + devices: 1 + min_epochs: 15 + max_epochs: 40 \ No newline at end of file diff --git a/conf/sen12ms.yaml b/conf/sen12ms.yaml index 3946774328a..553d5c996e8 100644 --- a/conf/sen12ms.yaml +++ b/conf/sen12ms.yaml @@ -1,22 +1,25 @@ -program: - experiment_name: sen12ms_test - overwrite: True +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + in_channels: 15 + num_classes: 11 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.SEN12MSDataModule + root: "data/sen12ms" + band_set: "all" + batch_size: 32 + num_workers: 4 + trainer: + _target_: lightning.pytorch.Trainer + accelerator: gpu + devices: 1 min_epochs: 15 -experiment: - task: "sen12ms" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 2 - in_channels: 15 - num_classes: 11 - ignore_index: null - datamodule: - root: "data/sen12ms" - band_set: "all" - batch_size: 32 - num_workers: 4 + max_epochs: 40 \ No newline at end of file diff --git a/conf/so2sat.yaml b/conf/so2sat.yaml index f515622e0d8..b54025dfe51 100644 --- a/conf/so2sat.yaml +++ b/conf/so2sat.yaml @@ -1,21 +1,23 @@ +module: + _target_: torchgeo.trainers.ClassificationTask + loss: "ce" + model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + in_channels: 18 + num_classes: 17 + +datamodule: + _target_: torchgeo.datamodules.So2SatDataModule + root: "data/so2sat" + batch_size: 128 + num_workers: 4 + band_set: "all" + trainer: + _target_: lightning.pytorch.Trainer accelerator: gpu devices: 1 - 
min_epochs: 10 - max_epochs: 40 - benchmark: True -experiment: - task: "so2sat" - module: - loss: "ce" - model: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - in_channels: 18 - num_classes: 17 - datamodule: - root: "data/so2sat" - batch_size: 128 - num_workers: 4 - band_set: "all" + min_epochs: 15 + max_epochs: 40 \ No newline at end of file diff --git a/conf/spacenet1.yaml b/conf/spacenet1.yaml index 5162c70c9d4..3bfd735680d 100644 --- a/conf/spacenet1.yaml +++ b/conf/spacenet1.yaml @@ -1,25 +1,24 @@ -program: - overwrite: False +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: "imagenet" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 3 + num_classes: 3 + ignore_index: 0 + +datamodule: + _target_: torchgeo.datamodules.SpaceNet1DataModule + root: "data/spacenet" + batch_size: 32 + num_workers: 4 + trainer: + _target_: lightning.pytorch.Trainer accelerator: gpu - devices: 3 - min_epochs: 50 - max_epochs: 200 - benchmark: True -experiment: - name: "spacenet-example" - task: "sen12ms" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: "imagenet" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 3 - num_classes: 3 - ignore_index: 0 - datamodule: - root: "data/spacenet" - batch_size: 32 - num_workers: 4 + devices: 1 + min_epochs: 15 + max_epochs: 40 \ No newline at end of file diff --git a/conf/ucmerced.yaml b/conf/ucmerced.yaml index 4ab6612d1ae..bae2aab676e 100644 --- a/conf/ucmerced.yaml +++ b/conf/ucmerced.yaml @@ -1,14 +1,22 @@ -experiment: - task: "ucmerced" - module: - loss: "ce" - model: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 3 - num_classes: 21 - datamodule: - root: "data/ucmerced" - batch_size: 128 - num_workers: 4 +module: + _target_: torchgeo.trainers.ClassificationTask + loss: "ce" + model: "resnet18" + 
weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 3 + num_classes: 21 + +datamodule: + _target_: torchgeo.datamodules.UCMercedDataModule + root: "data/ucmerced" + batch_size: 128 + num_workers: 4 + +trainer: + _target_: lightning.pytorch.Trainer + accelerator: gpu + devices: 1 + min_epochs: 15 + max_epochs: 40 \ No newline at end of file diff --git a/conf/vaihingen2d.yaml b/conf/vaihingen2d.yaml index c6fd448c6dd..db6248b052f 100644 --- a/conf/vaihingen2d.yaml +++ b/conf/vaihingen2d.yaml @@ -1,20 +1,28 @@ -experiment: - task: "vaihingen2d" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 3 - num_classes: 7 - num_filters: 1 - ignore_index: null - datamodule: - root: "data/vaihingen" - batch_size: 1 - patch_size: 64 - val_split_pct: 0.5 - num_workers: 0 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 3 + num_classes: 7 + num_filters: 1 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.Vaihingen2DDataModule + root: "data/vaihingen" + batch_size: 1 + patch_size: 64 + val_split_pct: 0.5 + num_workers: 0 + +trainer: + _target_: lightning.pytorch.Trainer + accelerator: gpu + devices: 1 + min_epochs: 15 + max_epochs: 40 \ No newline at end of file diff --git a/environment.yml b/environment.yml index 1f289475f77..51ec988b274 100644 --- a/environment.yml +++ b/environment.yml @@ -21,6 +21,7 @@ dependencies: - pip: - black[jupyter]>=21.8 - flake8>=3.8 + - hydra-core>=1 - ipywidgets>=7 - isort[colors]>=5.8 - kornia>=0.6.5 @@ -29,7 +30,7 @@ dependencies: - mypy>=0.900 - nbmake>=1.3.3 - nbsphinx>=0.8.5 - - omegaconf>=2.1 + - omegaconf>=2.0.1 - opencv-python>=4.4.0.46 - pandas>=1.1.3 - pillow>=8 diff --git a/evaluate.py b/evaluate.py 
deleted file mode 100755 index beab7c25a8f..00000000000 --- a/evaluate.py +++ /dev/null @@ -1,294 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""torchgeo model evaluation script.""" - -import argparse -import csv -import os -from typing import Any, Union, cast - -import lightning.pytorch as pl -import torch -from torchmetrics import MetricCollection -from torchmetrics.classification import BinaryAccuracy, BinaryJaccardIndex - -from torchgeo.trainers import ( - ClassificationTask, - ObjectDetectionTask, - SemanticSegmentationTask, -) -from train import TASK_TO_MODULES_MAPPING - - -def set_up_parser() -> argparse.ArgumentParser: - """Set up the argument parser. - - Returns: - the argument parser - """ - parser = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--task", - choices=TASK_TO_MODULES_MAPPING.keys(), - type=str, - help="name of task to test", - ) - parser.add_argument( - "--input-checkpoint", - required=True, - help="path to the checkpoint file to test", - metavar="CKPT", - ) - parser.add_argument( - "--gpu", default=0, type=int, help="GPU ID to use", metavar="ID" - ) - parser.add_argument( - "--root", - required=True, - type=str, - help="root directory of the dataset for the accompanying task", - ) - parser.add_argument( - "-b", - "--batch-size", - default=2**4, - type=int, - help="number of samples in each mini-batch", - metavar="SIZE", - ) - parser.add_argument( - "-w", - "--num-workers", - default=6, - type=int, - help="number of workers for parallel data loading", - metavar="NUM", - ) - parser.add_argument( - "--seed", default=0, type=int, help="random seed for reproducibility" - ) - parser.add_argument( - "--output-fn", - required=True, - type=str, - help="path to the CSV file to write results", - metavar="FILE", - ) - parser.add_argument( - "-v", "--verbose", 
action="store_true", help="print results to stdout" - ) - - return parser - - -def run_eval_loop( - model: pl.LightningModule, - dataloader: Any, - device: torch.device, - metrics: MetricCollection, -) -> Any: - """Runs a standard test loop over a dataloader and records metrics. - - Args: - model: the model used for inference - dataloader: the dataloader to get samples from - device: the device to put data on - metrics: a torchmetrics compatible metric collection to score the output - from the model - - Returns: - the result of ``metrics.compute()`` - """ - for batch in dataloader: - x = batch["image"].to(device) - if "mask" in batch: - y = batch["mask"].to(device) - elif "label" in batch: - y = batch["label"].to(device) - elif "boxes" in batch: - y = [ - { - "boxes": batch["boxes"][i].to(device), - "labels": batch["labels"][i].to(device), - } - for i in range(len(batch["image"])) - ] - with torch.inference_mode(): - y_pred = model(x) - metrics(y_pred, y) - results = metrics.compute() - metrics.reset() - return results - - -def main(args: argparse.Namespace) -> None: - """High-level pipeline. - - Runs a model checkpoint on a test set and saves results to file. 
- - Args: - args: command-line arguments - """ - assert os.path.exists(args.input_checkpoint) - assert os.path.exists(args.root) - TASK = TASK_TO_MODULES_MAPPING[args.task][0] - DATAMODULE = TASK_TO_MODULES_MAPPING[args.task][1] - - # Loads the saved model from checkpoint based on the `args.task` name that was - # passed as input - model = TASK.load_from_checkpoint(args.input_checkpoint) - model.freeze() - model.eval() - - dm = DATAMODULE( - seed=args.seed, - root=args.root, - num_workers=args.num_workers, - batch_size=args.batch_size, - ) - dm.setup("validate") - - # Record model hyperparameters - hparams = cast(dict[str, Union[str, float]], model.hparams) - if issubclass(TASK, ClassificationTask): - val_row = { - "split": "val", - "model": hparams["model"], - "learning_rate": hparams["learning_rate"], - "weights": hparams["weights"], - "loss": hparams["loss"], - } - - test_row = { - "split": "test", - "model": hparams["model"], - "learning_rate": hparams["learning_rate"], - "weights": hparams["weights"], - "loss": hparams["loss"], - } - elif issubclass(TASK, SemanticSegmentationTask): - val_row = { - "split": "val", - "model": hparams["model"], - "backbone": hparams["backbone"], - "weights": hparams["weights"], - "learning_rate": hparams["learning_rate"], - "loss": hparams["loss"], - } - - test_row = { - "split": "test", - "model": hparams["model"], - "backbone": hparams["backbone"], - "weights": hparams["weights"], - "learning_rate": hparams["learning_rate"], - "loss": hparams["loss"], - } - elif issubclass(TASK, ObjectDetectionTask): - val_row = { - "split": "val", - "detection_model": hparams["detection_model"], - "backbone": hparams["backbone"], - "learning_rate": hparams["learning_rate"], - } - - test_row = { - "split": "test", - "detection_model": hparams["detection_model"], - "backbone": hparams["backbone"], - "learning_rate": hparams["learning_rate"], - } - else: - raise ValueError(f"{TASK} is not supported") - - # Compute metrics - device = 
torch.device("cuda:%d" % (args.gpu)) - model = model.to(device) - - if args.task == "etci2021": # Custom metric setup for testing ETCI2021 - metrics = MetricCollection([BinaryAccuracy(), BinaryJaccardIndex()]).to(device) - - val_results = run_eval_loop(model, dm.val_dataloader(), device, metrics) - test_results = run_eval_loop(model, dm.test_dataloader(), device, metrics) - - val_row.update( - { - "overall_accuracy": val_results["Accuracy"].item(), - "jaccard_index": val_results["JaccardIndex"][1].item(), - } - ) - test_row.update( - { - "overall_accuracy": test_results["Accuracy"].item(), - "jaccard_index": test_results["JaccardIndex"][1].item(), - } - ) - else: # Test with PyTorch Lightning as usual - model.val_metrics = cast(MetricCollection, model.val_metrics) - model.test_metrics = cast(MetricCollection, model.test_metrics) - - val_results = run_eval_loop( - model, dm.val_dataloader(), device, model.val_metrics - ) - test_results = run_eval_loop( - model, dm.test_dataloader(), device, model.test_metrics - ) - - # Save the results and model hyperparameters to a CSV file - if issubclass(TASK, ClassificationTask): - val_row.update( - { - "average_accuracy": val_results["val_AverageAccuracy"].item(), - "overall_accuracy": val_results["val_OverallAccuracy"].item(), - } - ) - test_row.update( - { - "average_accuracy": test_results["test_AverageAccuracy"].item(), - "overall_accuracy": test_results["test_OverallAccuracy"].item(), - } - ) - elif issubclass(TASK, SemanticSegmentationTask): - val_row.update( - { - "overall_accuracy": val_results["val_Accuracy"].item(), - "jaccard_index": val_results["val_JaccardIndex"].item(), - } - ) - test_row.update( - { - "overall_accuracy": test_results["test_Accuracy"].item(), - "jaccard_index": test_results["test_JaccardIndex"].item(), - } - ) - elif issubclass(TASK, ObjectDetectionTask): - val_row.update({"map": val_results["map"].item()}) - test_row.update({"map": test_results["map"].item()}) - - assert set(val_row.keys()) == 
set(test_row.keys()) - fieldnames = list(test_row.keys()) - - # Write to file - if not os.path.exists(args.output_fn): - with open(args.output_fn, "w") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - with open(args.output_fn, "a") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writerow(val_row) - writer.writerow(test_row) - - -if __name__ == "__main__": - parser = set_up_parser() - args = parser.parse_args() - - pl.seed_everything(args.seed) - - main(args) diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index 8c8e9bd32c2..c484fb51efc 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -46,9 +46,10 @@ pydocstyle[toml]==6.1.0 pyupgrade==2.8.0 # tests +hydra-core==1.0.0 mypy==0.900 nbmake==1.3.3 -omegaconf==2.1.0 +omegaconf==2.0.1 pytest==6.1.2 pytest-cov==2.4.0 tensorboard==2.9.1 diff --git a/requirements/tests.txt b/requirements/tests.txt index 8330072dde0..284d6f858ab 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,4 +1,5 @@ # tests +hydra-core==1.3.2 mypy==1.2.0 nbmake==1.4.1 omegaconf==2.3.0 diff --git a/setup.cfg b/setup.cfg index 98028441d90..d4dcefc9a39 100644 --- a/setup.cfg +++ b/setup.cfg @@ -117,12 +117,14 @@ style = # pyupgrade 2.8+ required for --py39-plus flag pyupgrade>=2.8,<4 tests = + # hydra-core 1+ required for omegaconf 2 support + hydra-core>=1 # mypy 0.900+ required for pyproject.toml support mypy>=0.900,<2 # nbmake 1.3.3+ required for variable mocking nbmake>=1.3.3,<2 - # omegaconf 2.1+ required for to_object method - omegaconf>=2.1,<3 + # omegaconf 2+ required by lightning, 2.0.1+ required by hydra-core + omegaconf>=2.0.1 # pytest 6.1.2+ required by nbmake pytest>=6.1.2,<8 # pytest-cov 2.4+ required for pytest --cov flags diff --git a/tests/conf/bigearthnet_all.yaml b/tests/conf/bigearthnet_all.yaml index e885c9db4c7..f034c155b9b 100644 --- a/tests/conf/bigearthnet_all.yaml +++ b/tests/conf/bigearthnet_all.yaml @@ -1,17 +1,18 
@@ -experiment: - task: "bigearthnet" - module: - loss: "bce" - model: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - in_channels: 14 - num_classes: 19 - datamodule: - root: "tests/data/bigearthnet" - bands: "all" - num_classes: ${experiment.module.num_classes} - download: true - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.MultiLabelClassificationTask + loss: "bce" + model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + in_channels: 14 + num_classes: 19 + +datamodule: + _target_: torchgeo.datamodules.BigEarthNetDataModule + root: "tests/data/bigearthnet" + bands: "all" + num_classes: ${module.num_classes} + download: true + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/bigearthnet_s1.yaml b/tests/conf/bigearthnet_s1.yaml index 09b71cbd84c..fa49d81c775 100644 --- a/tests/conf/bigearthnet_s1.yaml +++ b/tests/conf/bigearthnet_s1.yaml @@ -1,17 +1,18 @@ -experiment: - task: "bigearthnet" - module: - loss: "bce" - model: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - in_channels: 2 - num_classes: 19 - datamodule: - root: "tests/data/bigearthnet" - bands: "s1" - num_classes: ${experiment.module.num_classes} - download: true - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.MultiLabelClassificationTask + loss: "bce" + model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + in_channels: 2 + num_classes: 19 + +datamodule: + _target_: torchgeo.datamodules.BigEarthNetDataModule + root: "tests/data/bigearthnet" + bands: "s1" + num_classes: ${module.num_classes} + download: true + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/bigearthnet_s2.yaml b/tests/conf/bigearthnet_s2.yaml index 487b1433810..3677de83c79 100644 --- a/tests/conf/bigearthnet_s2.yaml +++ b/tests/conf/bigearthnet_s2.yaml @@ -1,17 
+1,18 @@ -experiment: - task: "bigearthnet" - module: - loss: "bce" - model: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - in_channels: 12 - num_classes: 19 - datamodule: - root: "tests/data/bigearthnet" - bands: "s2" - num_classes: ${experiment.module.num_classes} - download: true - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.MultiLabelClassificationTask + loss: "bce" + model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + in_channels: 12 + num_classes: 19 + +datamodule: + _target_: torchgeo.datamodules.BigEarthNetDataModule + root: "tests/data/bigearthnet" + bands: "s2" + num_classes: ${module.num_classes} + download: true + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/chesapeake_cvpr_5.yaml b/tests/conf/chesapeake_cvpr_5.yaml index 7ef269dd661..b4f345c3ab8 100644 --- a/tests/conf/chesapeake_cvpr_5.yaml +++ b/tests/conf/chesapeake_cvpr_5.yaml @@ -1,27 +1,28 @@ -experiment: - task: "chesapeake_cvpr" - module: - loss: "ce" - model: "unet" - backbone: "resnet50" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 4 - num_classes: 5 - num_filters: 1 - ignore_index: null - datamodule: - root: "tests/data/chesapeake/cvpr" - download: true - train_splits: - - "de-test" - val_splits: - - "de-test" - test_splits: - - "de-test" - batch_size: 2 - patch_size: 64 - num_workers: 0 - class_set: ${experiment.module.num_classes} - use_prior_labels: False +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet50" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 4 + num_classes: 5 + num_filters: 1 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.ChesapeakeCVPRDataModule + root: "tests/data/chesapeake/cvpr" + download: true + train_splits: + - "de-test" + val_splits: + - "de-test" + 
test_splits: + - "de-test" + batch_size: 2 + patch_size: 64 + num_workers: 0 + class_set: ${module.num_classes} + use_prior_labels: False \ No newline at end of file diff --git a/tests/conf/chesapeake_cvpr_7.yaml b/tests/conf/chesapeake_cvpr_7.yaml index 653f4934ca0..634440e680e 100644 --- a/tests/conf/chesapeake_cvpr_7.yaml +++ b/tests/conf/chesapeake_cvpr_7.yaml @@ -1,27 +1,28 @@ -experiment: - task: "chesapeake_cvpr" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 4 - num_classes: 7 - num_filters: 1 - ignore_index: null - weights: null - datamodule: - root: "tests/data/chesapeake/cvpr" - download: true - train_splits: - - "de-test" - val_splits: - - "de-test" - test_splits: - - "de-test" - batch_size: 2 - patch_size: 64 - num_workers: 0 - class_set: ${experiment.module.num_classes} - use_prior_labels: False +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 4 + num_classes: 7 + num_filters: 1 + ignore_index: null + weights: null + +datamodule: + _target_: torchgeo.datamodules.ChesapeakeCVPRDataModule + root: "tests/data/chesapeake/cvpr" + download: true + train_splits: + - "de-test" + val_splits: + - "de-test" + test_splits: + - "de-test" + batch_size: 2 + patch_size: 64 + num_workers: 0 + class_set: ${module.num_classes} + use_prior_labels: False \ No newline at end of file diff --git a/tests/conf/chesapeake_cvpr_prior_byol.yaml b/tests/conf/chesapeake_cvpr_prior_byol.yaml index 3e9713fbb59..6b6841d8f65 100644 --- a/tests/conf/chesapeake_cvpr_prior_byol.yaml +++ b/tests/conf/chesapeake_cvpr_prior_byol.yaml @@ -1,27 +1,23 @@ -experiment: - task: "chesapeake_cvpr" - module: - loss: "ce" - model: "unet" - backbone: "resnet50" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 4 - num_classes: 5 - num_filters: 1 - 
ignore_index: null - weights: null - datamodule: - root: "tests/data/chesapeake/cvpr" - download: true - train_splits: - - "de-test" - val_splits: - - "de-test" - test_splits: - - "de-test" - batch_size: 2 - patch_size: 64 - num_workers: 0 - class_set: ${experiment.module.num_classes} - use_prior_labels: True +module: + _target_: torchgeo.trainers.BYOLTask + in_channels: 4 + backbone: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + +datamodule: + _target_: torchgeo.datamodules.ChesapeakeCVPRDataModule + root: "tests/data/chesapeake/cvpr" + download: true + train_splits: + - "de-test" + val_splits: + - "de-test" + test_splits: + - "de-test" + batch_size: 2 + patch_size: 64 + num_workers: 0 + class_set: 5 + use_prior_labels: True \ No newline at end of file diff --git a/tests/conf/cowc_counting.yaml b/tests/conf/cowc_counting.yaml index fc3218e8fef..76eb04763a6 100644 --- a/tests/conf/cowc_counting.yaml +++ b/tests/conf/cowc_counting.yaml @@ -1,14 +1,15 @@ -experiment: - task: cowc_counting - module: - model: resnet18 - weights: null - num_outputs: 1 - in_channels: 3 - learning_rate: 1e-3 - learning_rate_schedule_patience: 2 - datamodule: - root: "tests/data/cowc_counting" - download: true - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.RegressionTask + model: resnet18 + weights: null + num_outputs: 1 + in_channels: 3 + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + +datamodule: + _target_: torchgeo.datamodules.COWCCountingDataModule + root: "tests/data/cowc_counting" + download: true + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/cyclone.yaml b/tests/conf/cyclone.yaml index b3323d28999..91a477a144d 100644 --- a/tests/conf/cyclone.yaml +++ b/tests/conf/cyclone.yaml @@ -1,14 +1,15 @@ -experiment: - task: "cyclone" - module: - model: "resnet18" - weights: null - num_outputs: 1 - in_channels: 3 - learning_rate: 1e-3 - learning_rate_schedule_patience: 2 - 
datamodule: - root: "tests/data/cyclone" - download: true - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.RegressionTask + model: "resnet18" + weights: null + num_outputs: 1 + in_channels: 3 + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + +datamodule: + _target_: torchgeo.datamodules.TropicalCycloneDataModule + root: "tests/data/cyclone" + download: true + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/deepglobelandcover.yaml b/tests/conf/deepglobelandcover.yaml index e27fe1271c2..09b0f4d9414 100644 --- a/tests/conf/deepglobelandcover.yaml +++ b/tests/conf/deepglobelandcover.yaml @@ -1,20 +1,21 @@ -experiment: - task: "deepglobelandcover" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 3 - num_classes: 7 - num_filters: 1 - ignore_index: null - datamodule: - root: "tests/data/deepglobelandcover" - batch_size: 1 - patch_size: 2 - val_split_pct: 0.5 - num_workers: 0 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 3 + num_classes: 7 + num_filters: 1 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.DeepGlobeLandCoverDataModule + root: "tests/data/deepglobelandcover" + batch_size: 1 + patch_size: 2 + val_split_pct: 0.5 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/etci2021.yaml b/tests/conf/etci2021.yaml index cbb766ea522..65c75374431 100644 --- a/tests/conf/etci2021.yaml +++ b/tests/conf/etci2021.yaml @@ -1,17 +1,18 @@ -experiment: - task: "etci2021" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 6 - num_classes: 2 - ignore_index: 0 - datamodule: - root: 
"tests/data/etci2021" - download: true - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 6 + num_classes: 2 + ignore_index: 0 + +datamodule: + _target_: torchgeo.datamodules.ETCI2021DataModule + root: "tests/data/etci2021" + download: true + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/eurosat.yaml b/tests/conf/eurosat.yaml index a4cbc9eb525..8e39dd50557 100644 --- a/tests/conf/eurosat.yaml +++ b/tests/conf/eurosat.yaml @@ -1,15 +1,16 @@ -experiment: - task: "eurosat" - module: - loss: "ce" - model: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - in_channels: 13 - num_classes: 2 - datamodule: - root: "tests/data/eurosat" - download: true - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.ClassificationTask + loss: "ce" + model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + in_channels: 13 + num_classes: 2 + +datamodule: + _target_: torchgeo.datamodules.EuroSATDataModule + root: "tests/data/eurosat" + download: true + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/eurosat100.yaml b/tests/conf/eurosat100.yaml new file mode 100644 index 00000000000..b1e5fe6438b --- /dev/null +++ b/tests/conf/eurosat100.yaml @@ -0,0 +1,16 @@ +module: + _target_: torchgeo.trainers.ClassificationTask + loss: "ce" + model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + in_channels: 13 + num_classes: 2 + +datamodule: + _target_: torchgeo.datamodules.EuroSAT100DataModule + root: "tests/data/eurosat" + download: true + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/fire_risk.yaml b/tests/conf/fire_risk.yaml index 4c13aeb05fd..0c86285235a 100644 --- 
a/tests/conf/fire_risk.yaml +++ b/tests/conf/fire_risk.yaml @@ -1,15 +1,16 @@ -experiment: - task: "fire_risk" - module: - loss: "ce" - model: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - in_channels: 3 - num_classes: 5 - datamodule: - root: "tests/data/fire_risk" - download: false - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.ClassificationTask + loss: "ce" + model: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 3 + num_classes: 5 + +datamodule: + _target_: torchgeo.datamodules.FireRiskDataModule + root: "tests/data/fire_risk" + download: false + batch_size: 2 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/gid15.yaml b/tests/conf/gid15.yaml index baaea0e1ba2..3af0a01f24e 100644 --- a/tests/conf/gid15.yaml +++ b/tests/conf/gid15.yaml @@ -1,21 +1,22 @@ -experiment: - task: "gid15" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 3 - num_classes: 16 - num_filters: 1 - ignore_index: null - datamodule: - root: "tests/data/gid15" - download: true - batch_size: 1 - patch_size: 2 - val_split_pct: 0.5 - num_workers: 0 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 3 + num_classes: 16 + num_filters: 1 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.GID15DataModule + root: "tests/data/gid15" + download: true + batch_size: 1 + patch_size: 2 + val_split_pct: 0.5 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/inria.yaml b/tests/conf/inria.yaml index 995c073146b..04af3433f1e 100644 --- a/tests/conf/inria.yaml +++ b/tests/conf/inria.yaml @@ -1,19 +1,20 @@ -experiment: - task: "inria" - module: - loss: 
"ce" - model: "unet" - backbone: "resnet18" - weights: "imagenet" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 3 - num_classes: 2 - ignore_index: null - datamodule: - root: "tests/data/inria" - batch_size: 1 - patch_size: 2 - num_workers: 0 - val_split_pct: 0.2 - test_split_pct: 0.2 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: "imagenet" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 3 + num_classes: 2 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.InriaAerialImageLabelingDataModule + root: "tests/data/inria" + batch_size: 1 + patch_size: 2 + num_workers: 0 + val_split_pct: 0.2 + test_split_pct: 0.2 \ No newline at end of file diff --git a/tests/conf/l7irish.yaml b/tests/conf/l7irish.yaml index 1946e80ce2d..cb54362d964 100644 --- a/tests/conf/l7irish.yaml +++ b/tests/conf/l7irish.yaml @@ -1,21 +1,22 @@ -experiment: - task: "l7irish" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 9 - num_classes: 5 - num_filters: 1 - ignore_index: 0 - datamodule: - root: "tests/data/l7irish" - download: true - batch_size: 1 - patch_size: 32 - length: 5 - num_workers: 0 \ No newline at end of file +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 9 + num_classes: 5 + num_filters: 1 + ignore_index: 0 + +datamodule: + _target_: torchgeo.datamodules.L7IrishDataModule + root: "tests/data/l7irish" + download: true + batch_size: 1 + patch_size: 32 + length: 5 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/l8biome.yaml b/tests/conf/l8biome.yaml index b04211074b2..796266d2e24 100644 --- a/tests/conf/l8biome.yaml +++ 
b/tests/conf/l8biome.yaml @@ -1,21 +1,22 @@ -experiment: - task: "l8biome" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 11 - num_classes: 5 - num_filters: 1 - ignore_index: null - datamodule: - root: "tests/data/l8biome" - download: true - batch_size: 1 - patch_size: 32 - length: 5 - num_workers: 0 \ No newline at end of file +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 11 + num_classes: 5 + num_filters: 1 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.L8BiomeDataModule + root: "tests/data/l8biome" + download: true + batch_size: 1 + patch_size: 32 + length: 5 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/landcoverai.yaml b/tests/conf/landcoverai.yaml index 9bffc96b83d..20ec3653471 100644 --- a/tests/conf/landcoverai.yaml +++ b/tests/conf/landcoverai.yaml @@ -1,19 +1,20 @@ -experiment: - task: "landcoverai" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 3 - num_classes: 6 - num_filters: 1 - ignore_index: null - datamodule: - root: "tests/data/landcoverai" - download: true - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 3 + num_classes: 6 + num_filters: 1 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.LandCoverAIDataModule + root: "tests/data/landcoverai" + download: true + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/loveda.yaml 
b/tests/conf/loveda.yaml index df062a0e600..92f324cb018 100644 --- a/tests/conf/loveda.yaml +++ b/tests/conf/loveda.yaml @@ -1,19 +1,20 @@ -experiment: - task: "loveda" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 3 - num_classes: 8 - num_filters: 1 - ignore_index: null - datamodule: - root: "tests/data/loveda" - download: true - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 3 + num_classes: 8 + num_filters: 1 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.LoveDADataModule + root: "tests/data/loveda" + download: true + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/naipchesapeake.yaml b/tests/conf/naipchesapeake.yaml index 9cd0e2beb96..06aa921a6dd 100644 --- a/tests/conf/naipchesapeake.yaml +++ b/tests/conf/naipchesapeake.yaml @@ -1,20 +1,21 @@ -experiment: - task: "naipchesapeake" - module: - loss: "ce" - model: "deeplabv3+" - backbone: "resnet34" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 2 - in_channels: 4 - num_classes: 14 - num_filters: 1 - ignore_index: null - datamodule: - naip_root: "tests/data/naip" - chesapeake_root: "tests/data/chesapeake/BAYWIDE" - chesapeake_download: true - batch_size: 2 - num_workers: 0 - patch_size: 32 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "deeplabv3+" + backbone: "resnet34" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + in_channels: 4 + num_classes: 14 + num_filters: 1 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.NAIPChesapeakeDataModule + naip_root: "tests/data/naip" + chesapeake_root: "tests/data/chesapeake/BAYWIDE" + 
chesapeake_download: true + batch_size: 2 + num_workers: 0 + patch_size: 32 \ No newline at end of file diff --git a/tests/conf/nasa_marine_debris.yaml b/tests/conf/nasa_marine_debris.yaml index 5528b38c39c..01e6de32916 100644 --- a/tests/conf/nasa_marine_debris.yaml +++ b/tests/conf/nasa_marine_debris.yaml @@ -1,14 +1,15 @@ -experiment: - task: "nasa_marine_debris" - module: - model: "faster-rcnn" - backbone: "resnet18" - num_classes: 2 - learning_rate: 1.2e-4 - learning_rate_schedule_patience: 6 - verbose: false - datamodule: - root: "tests/data/nasa_marine_debris" - download: true - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.ObjectDetectionTask + model: "faster-rcnn" + backbone: "resnet18" + num_classes: 2 + learning_rate: 1.2e-4 + learning_rate_schedule_patience: 6 + verbose: false + +datamodule: + _target_: torchgeo.datamodules.NASAMarineDebrisDataModule + root: "tests/data/nasa_marine_debris" + download: true + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/potsdam2d.yaml b/tests/conf/potsdam2d.yaml index 7492a8c0c86..9ac40d93681 100644 --- a/tests/conf/potsdam2d.yaml +++ b/tests/conf/potsdam2d.yaml @@ -1,20 +1,21 @@ -experiment: - task: "potsdam2d" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 4 - num_classes: 6 - num_filters: 1 - ignore_index: null - datamodule: - root: "tests/data/potsdam" - batch_size: 1 - patch_size: 2 - val_split_pct: 0.5 - num_workers: 0 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 4 + num_classes: 6 + num_filters: 1 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.Potsdam2DDataModule + root: "tests/data/potsdam" + batch_size: 1 + patch_size: 2 + 
val_split_pct: 0.5 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/resisc45.yaml b/tests/conf/resisc45.yaml index fd354ad09f8..7dee7bc43fe 100644 --- a/tests/conf/resisc45.yaml +++ b/tests/conf/resisc45.yaml @@ -1,15 +1,16 @@ -experiment: - task: "resisc45" - module: - loss: "ce" - model: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - in_channels: 3 - num_classes: 3 - datamodule: - root: "tests/data/resisc45" - download: true - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.ClassificationTask + loss: "ce" + model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + in_channels: 3 + num_classes: 3 + +datamodule: + _target_: torchgeo.datamodules.RESISC45DataModule + root: "tests/data/resisc45" + download: true + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/seco_byol_1.yaml b/tests/conf/seco_byol_1.yaml index 50379b07f68..5f7e0b91b20 100644 --- a/tests/conf/seco_byol_1.yaml +++ b/tests/conf/seco_byol_1.yaml @@ -1,13 +1,14 @@ -experiment: - task: "seco" - module: - in_channels: 3 - backbone: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - datamodule: - root: "tests/data/seco" - seasons: 1 - batch_size: 2 - num_workers: 0 +module: + _target_: torchgeo.trainers.BYOLTask + in_channels: 3 + backbone: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + +datamodule: + _target_: torchgeo.datamodules.SeasonalContrastS2DataModule + root: "tests/data/seco" + seasons: 1 + batch_size: 2 + num_workers: 0 diff --git a/tests/conf/seco_byol_2.yaml b/tests/conf/seco_byol_2.yaml index e07354cb2a2..07ff81c0132 100644 --- a/tests/conf/seco_byol_2.yaml +++ b/tests/conf/seco_byol_2.yaml @@ -1,13 +1,14 @@ -experiment: - task: "seco" - module: - in_channels: 3 - backbone: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - 
datamodule: - root: "tests/data/seco" - seasons: 2 - batch_size: 2 - num_workers: 0 +module: + _target_: torchgeo.trainers.BYOLTask + in_channels: 3 + backbone: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + +datamodule: + _target_: torchgeo.datamodules.SeasonalContrastS2DataModule + root: "tests/data/seco" + seasons: 2 + batch_size: 2 + num_workers: 0 diff --git a/tests/conf/sen12ms_all.yaml b/tests/conf/sen12ms_all.yaml index e5676876550..0bdbc54ddff 100644 --- a/tests/conf/sen12ms_all.yaml +++ b/tests/conf/sen12ms_all.yaml @@ -1,17 +1,18 @@ -experiment: - task: "sen12ms" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 2 - in_channels: 15 - num_classes: 11 - ignore_index: null - datamodule: - root: "tests/data/sen12ms" - band_set: "all" - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + in_channels: 15 + num_classes: 11 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.SEN12MSDataModule + root: "tests/data/sen12ms" + band_set: "all" + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/sen12ms_s1.yaml b/tests/conf/sen12ms_s1.yaml index 5289c3c8b63..8cf4435c624 100644 --- a/tests/conf/sen12ms_s1.yaml +++ b/tests/conf/sen12ms_s1.yaml @@ -1,18 +1,19 @@ -experiment: - task: "sen12ms" - module: - loss: "focal" - model: "fcn" - num_filters: 1 - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 2 - in_channels: 2 - num_classes: 11 - ignore_index: null - datamodule: - root: "tests/data/sen12ms" - band_set: "s1" - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "focal" + model: "fcn" + num_filters: 1 + backbone: "resnet18" + 
weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + in_channels: 2 + num_classes: 11 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.SEN12MSDataModule + root: "tests/data/sen12ms" + band_set: "s1" + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/sen12ms_s2_all.yaml b/tests/conf/sen12ms_s2_all.yaml index f1499b523e3..a7712cf4a78 100644 --- a/tests/conf/sen12ms_s2_all.yaml +++ b/tests/conf/sen12ms_s2_all.yaml @@ -1,17 +1,18 @@ -experiment: - task: "sen12ms" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 2 - in_channels: 13 - num_classes: 11 - ignore_index: null - datamodule: - root: "tests/data/sen12ms" - band_set: "s2-all" - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + in_channels: 13 + num_classes: 11 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.SEN12MSDataModule + root: "tests/data/sen12ms" + band_set: "s2-all" + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/sen12ms_s2_reduced.yaml b/tests/conf/sen12ms_s2_reduced.yaml index 72e85b56fc3..9493519da2d 100644 --- a/tests/conf/sen12ms_s2_reduced.yaml +++ b/tests/conf/sen12ms_s2_reduced.yaml @@ -1,17 +1,18 @@ -experiment: - task: "sen12ms" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 2 - in_channels: 6 - num_classes: 11 - ignore_index: null - datamodule: - root: "tests/data/sen12ms" - band_set: "s2-reduced" - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + 
in_channels: 6 + num_classes: 11 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.SEN12MSDataModule + root: "tests/data/sen12ms" + band_set: "s2-reduced" + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/skippd.yaml b/tests/conf/skippd.yaml index 8f1c1cb655f..20ca10f24a6 100644 --- a/tests/conf/skippd.yaml +++ b/tests/conf/skippd.yaml @@ -1,14 +1,15 @@ -experiment: - task: "skippd" - module: - model: "resnet18" - weights: null - num_outputs: 1 - in_channels: 3 - learning_rate: 1e-3 - learning_rate_schedule_patience: 2 - datamodule: - root: "tests/data/skippd" - download: true - batch_size: 1 - num_workers: 0 \ No newline at end of file +module: + _target_: torchgeo.trainers.RegressionTask + model: "resnet18" + weights: null + num_outputs: 1 + in_channels: 3 + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + +datamodule: + _target_: torchgeo.datamodules.SKIPPDDataModule + root: "tests/data/skippd" + download: true + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/so2sat_all.yaml b/tests/conf/so2sat_all.yaml index a8d8c0bb8e3..1033918e0ff 100644 --- a/tests/conf/so2sat_all.yaml +++ b/tests/conf/so2sat_all.yaml @@ -1,15 +1,16 @@ -experiment: - task: "so2sat" - module: - loss: "ce" - model: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - in_channels: 18 - num_classes: 17 - datamodule: - root: "tests/data/so2sat" - batch_size: 1 - num_workers: 0 - band_set: "all" +module: + _target_: torchgeo.trainers.ClassificationTask + loss: "ce" + model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + in_channels: 18 + num_classes: 17 + +datamodule: + _target_: torchgeo.datamodules.So2SatDataModule + root: "tests/data/so2sat" + batch_size: 1 + num_workers: 0 + band_set: "all" \ No newline at end of file diff --git a/tests/conf/so2sat_s1.yaml b/tests/conf/so2sat_s1.yaml index 8c87ff55a53..44a437d0ec5 100644 
--- a/tests/conf/so2sat_s1.yaml +++ b/tests/conf/so2sat_s1.yaml @@ -1,15 +1,16 @@ -experiment: - task: "so2sat" - module: - loss: "focal" - model: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - in_channels: 8 - num_classes: 17 - datamodule: - root: "tests/data/so2sat" - batch_size: 1 - num_workers: 0 - band_set: "s1" +module: + _target_: torchgeo.trainers.ClassificationTask + loss: "focal" + model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + in_channels: 8 + num_classes: 17 + +datamodule: + _target_: torchgeo.datamodules.So2SatDataModule + root: "tests/data/so2sat" + batch_size: 1 + num_workers: 0 + band_set: "s1" \ No newline at end of file diff --git a/tests/conf/so2sat_s2.yaml b/tests/conf/so2sat_s2.yaml index ab9c573a197..b7474bc7705 100644 --- a/tests/conf/so2sat_s2.yaml +++ b/tests/conf/so2sat_s2.yaml @@ -1,15 +1,16 @@ -experiment: - task: "so2sat" - module: - loss: "jaccard" - model: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - in_channels: 10 - num_classes: 17 - datamodule: - root: "tests/data/so2sat" - batch_size: 1 - num_workers: 0 - band_set: "s2" +module: + _target_: torchgeo.trainers.ClassificationTask + loss: "jaccard" + model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + in_channels: 10 + num_classes: 17 + +datamodule: + _target_: torchgeo.datamodules.So2SatDataModule + root: "tests/data/so2sat" + batch_size: 1 + num_workers: 0 + band_set: "s2" \ No newline at end of file diff --git a/tests/conf/spacenet1.yaml b/tests/conf/spacenet1.yaml index 3f05a745573..e4feb50a37e 100644 --- a/tests/conf/spacenet1.yaml +++ b/tests/conf/spacenet1.yaml @@ -1,21 +1,22 @@ -experiment: - task: "spacenet1" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 3 - num_classes: 3 - 
num_filters: 1 - ignore_index: null - datamodule: - root: "tests/data/spacenet" - download: true - batch_size: 1 - num_workers: 0 - val_split_pct: 0.33 - test_split_pct: 0.33 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 3 + num_classes: 3 + num_filters: 1 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.SpaceNet1DataModule + root: "tests/data/spacenet" + download: true + batch_size: 1 + num_workers: 0 + val_split_pct: 0.33 + test_split_pct: 0.33 \ No newline at end of file diff --git a/tests/conf/ssl4eo_s12_byol_1.yaml b/tests/conf/ssl4eo_s12_byol_1.yaml index f9b99601efc..0bc3267ecc0 100644 --- a/tests/conf/ssl4eo_s12_byol_1.yaml +++ b/tests/conf/ssl4eo_s12_byol_1.yaml @@ -1,13 +1,14 @@ -experiment: - task: "ssl4eo_s12" - module: - in_channels: 13 - backbone: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - datamodule: - root: "tests/data/ssl4eo/s12" - seasons: 1 - batch_size: 2 - num_workers: 0 +module: + _target_: torchgeo.trainers.BYOLTask + in_channels: 13 + backbone: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + +datamodule: + _target_: torchgeo.datamodules.SSL4EOS12DataModule + root: "tests/data/ssl4eo/s12" + seasons: 1 + batch_size: 2 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/ssl4eo_s12_byol_2.yaml b/tests/conf/ssl4eo_s12_byol_2.yaml index 7679454bf93..cced864fc6e 100644 --- a/tests/conf/ssl4eo_s12_byol_2.yaml +++ b/tests/conf/ssl4eo_s12_byol_2.yaml @@ -1,13 +1,14 @@ -experiment: - task: "ssl4eo_s12" - module: - in_channels: 13 - backbone: "resnet18" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - weights: null - datamodule: - root: "tests/data/ssl4eo/s12" - seasons: 2 - batch_size: 2 - num_workers: 0 +module: + _target_: 
torchgeo.trainers.BYOLTask + in_channels: 13 + backbone: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + +datamodule: + _target_: torchgeo.datamodules.SSL4EOS12DataModule + root: "tests/data/ssl4eo/s12" + seasons: 2 + batch_size: 2 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/sustainbench_crop_yield.yaml b/tests/conf/sustainbench_crop_yield.yaml index 60903ea7d4c..2f48a83d02a 100644 --- a/tests/conf/sustainbench_crop_yield.yaml +++ b/tests/conf/sustainbench_crop_yield.yaml @@ -1,14 +1,15 @@ -experiment: - task: "sustainbench_crop_yield" - module: - model: "resnet18" - weights: null - num_outputs: 1 - in_channels: 9 - learning_rate: 1e-3 - learning_rate_schedule_patience: 2 - datamodule: - root: "tests/data/sustainbench_crop_yield" - download: true - batch_size: 1 - num_workers: 0 +module: + _target_: torchgeo.trainers.RegressionTask + model: "resnet18" + weights: null + num_outputs: 1 + in_channels: 9 + learning_rate: 1e-3 + learning_rate_schedule_patience: 2 + +datamodule: + _target_: torchgeo.datamodules.SustainBenchCropYieldDataModule + root: "tests/data/sustainbench_crop_yield" + download: true + batch_size: 1 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/ucmerced.yaml b/tests/conf/ucmerced.yaml index 3c544564ae8..22f61ff7cd0 100644 --- a/tests/conf/ucmerced.yaml +++ b/tests/conf/ucmerced.yaml @@ -1,15 +1,16 @@ -experiment: - task: "ucmerced" - module: - loss: "ce" - model: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 3 - num_classes: 2 - datamodule: - root: "tests/data/ucmerced" - download: true - batch_size: 2 - num_workers: 0 +module: + _target_: torchgeo.trainers.ClassificationTask + loss: "ce" + model: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 3 + num_classes: 2 + +datamodule: + _target_: torchgeo.datamodules.UCMercedDataModule + root: "tests/data/ucmerced" 
+ download: true + batch_size: 2 + num_workers: 0 \ No newline at end of file diff --git a/tests/conf/vaihingen2d.yaml b/tests/conf/vaihingen2d.yaml index 7f542f3310b..8bd3043a673 100644 --- a/tests/conf/vaihingen2d.yaml +++ b/tests/conf/vaihingen2d.yaml @@ -1,20 +1,21 @@ -experiment: - task: "vaihingen2d" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 3 - num_classes: 7 - num_filters: 1 - ignore_index: null - datamodule: - root: "tests/data/vaihingen" - batch_size: 1 - patch_size: 2 - val_split_pct: 0.5 - num_workers: 0 +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: null + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + verbose: false + in_channels: 3 + num_classes: 7 + num_filters: 1 + ignore_index: null + +datamodule: + _target_: torchgeo.datamodules.Vaihingen2DDataModule + root: "tests/data/vaihingen" + batch_size: 1 + patch_size: 2 + val_split_pct: 0.5 + num_workers: 0 \ No newline at end of file diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index 67bfef7d5f1..4b0e9957c86 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -3,7 +3,7 @@ import os from pathlib import Path -from typing import Any, cast +from typing import Any import pytest import timm @@ -12,16 +12,12 @@ import torchvision from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from lightning.pytorch import LightningDataModule, Trainer +from hydra.utils import instantiate +from lightning.pytorch import Trainer from omegaconf import OmegaConf from torchvision.models import resnet18 from torchvision.models._api import WeightsEnum -from torchgeo.datamodules import ( - ChesapeakeCVPRDataModule, - SeasonalContrastS2DataModule, - SSL4EOS12DataModule, -) from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2 
from torchgeo.models import get_model_weights, list_models from torchgeo.trainers import BYOLTask @@ -54,25 +50,19 @@ def test_custom_augment_fn(self) -> None: class TestBYOLTask: @pytest.mark.parametrize( - "name,classname", + "name", [ - ("chesapeake_cvpr_prior_byol", ChesapeakeCVPRDataModule), - ("seco_byol_1", SeasonalContrastS2DataModule), - ("seco_byol_2", SeasonalContrastS2DataModule), - ("ssl4eo_s12_byol_1", SSL4EOS12DataModule), - ("ssl4eo_s12_byol_2", SSL4EOS12DataModule), + "chesapeake_cvpr_prior_byol", + "seco_byol_1", + "seco_byol_2", + "ssl4eo_s12_byol_1", + "ssl4eo_s12_byol_2", ], ) def test_trainer( - self, - monkeypatch: MonkeyPatch, - name: str, - classname: type[LightningDataModule], - fast_dev_run: bool, + self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) - conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(dict[str, dict[str, Any]], conf_dict) if name.startswith("seco"): monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2) @@ -81,14 +71,11 @@ def test_trainer( monkeypatch.setattr(SSL4EOS12, "__len__", lambda self: 2) # Instantiate datamodule - datamodule_kwargs = conf_dict["datamodule"] - datamodule = classname(**datamodule_kwargs) + datamodule = instantiate(conf.datamodule) # Instantiate model - model_kwargs = conf_dict["module"] - model = BYOLTask(**model_kwargs) - - model.backbone = SegmentationTestModel(**model_kwargs) + model = instantiate(conf.module) + model.backbone = SegmentationTestModel(**conf.module) # Instantiate trainer trainer = Trainer( diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 4f8bcbe6f90..4abdf95bb87 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -3,7 +3,7 @@ import os from pathlib import Path -from typing import Any, cast +from typing import Any import pytest import timm @@ -12,20 +12,16 @@ import 
torchvision from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from lightning.pytorch import LightningDataModule, Trainer +from hydra.utils import instantiate +from lightning.pytorch import Trainer from omegaconf import OmegaConf from torch.nn.modules import Module from torchvision.models._api import WeightsEnum from torchgeo.datamodules import ( BigEarthNetDataModule, - EuroSAT100DataModule, EuroSATDataModule, - FireRiskDataModule, MisconfigurationException, - RESISC45DataModule, - So2SatDataModule, - UCMercedDataModule, ) from torchgeo.datasets import BigEarthNet, EuroSAT from torchgeo.models import get_model_weights, list_models @@ -33,9 +29,7 @@ class ClassificationTestModel(Module): - def __init__( - self, in_chans: int = 3, num_classes: int = 1000, **kwargs: Any - ) -> None: + def __init__(self, in_chans: int = 3, num_classes: int = 10, **kwargs: Any) -> None: super().__init__() self.conv1 = nn.Conv2d(in_channels=in_chans, out_channels=1, kernel_size=1) self.pool = nn.AdaptiveAvgPool2d((1, 1)) @@ -74,40 +68,32 @@ def plot(*args: Any, **kwargs: Any) -> None: class TestClassificationTask: @pytest.mark.parametrize( - "name,classname", + "name", [ - ("eurosat", EuroSATDataModule), - ("eurosat", EuroSAT100DataModule), - ("fire_risk", FireRiskDataModule), - ("resisc45", RESISC45DataModule), - ("so2sat_all", So2SatDataModule), - ("so2sat_s1", So2SatDataModule), - ("so2sat_s2", So2SatDataModule), - ("ucmerced", UCMercedDataModule), + "eurosat", + "eurosat100", + "fire_risk", + "resisc45", + "so2sat_all", + "so2sat_s1", + "so2sat_s2", + "ucmerced", ], ) def test_trainer( - self, - monkeypatch: MonkeyPatch, - name: str, - classname: type[LightningDataModule], - fast_dev_run: bool, + self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: if name.startswith("so2sat"): pytest.importorskip("h5py", minversion="2.6") conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) - conf_dict = 
OmegaConf.to_object(conf.experiment) - conf_dict = cast(dict[str, dict[str, Any]], conf_dict) # Instantiate datamodule - datamodule_kwargs = conf_dict["datamodule"] - datamodule = classname(**datamodule_kwargs) + datamodule = instantiate(conf.datamodule) # Instantiate model monkeypatch.setattr(timm, "create_model", create_model) - model_kwargs = conf_dict["module"] - model = ClassificationTask(**model_kwargs) + model = instantiate(conf.module) # Instantiate trainer trainer = Trainer( @@ -239,32 +225,19 @@ def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None class TestMultiLabelClassificationTask: @pytest.mark.parametrize( - "name,classname", - [ - ("bigearthnet_all", BigEarthNetDataModule), - ("bigearthnet_s1", BigEarthNetDataModule), - ("bigearthnet_s2", BigEarthNetDataModule), - ], + "name", ["bigearthnet_all", "bigearthnet_s1", "bigearthnet_s2"] ) def test_trainer( - self, - monkeypatch: MonkeyPatch, - name: str, - classname: type[LightningDataModule], - fast_dev_run: bool, + self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) - conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(dict[str, dict[str, Any]], conf_dict) # Instantiate datamodule - datamodule_kwargs = conf_dict["datamodule"] - datamodule = classname(**datamodule_kwargs) + datamodule = instantiate(conf.datamodule) # Instantiate model monkeypatch.setattr(timm, "create_model", create_model) - model_kwargs = conf_dict["module"] - model = MultiLabelClassificationTask(**model_kwargs) + model = instantiate(conf.module) # Instantiate trainer trainer = Trainer( diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index 48b8b0d4579..38419ea0c05 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -2,14 +2,15 @@ # Licensed under the MIT License. 
import os -from typing import Any, cast +from typing import Any import pytest import torch import torch.nn as nn import torchvision.models.detection from _pytest.monkeypatch import MonkeyPatch -from lightning.pytorch import LightningDataModule, Trainer +from hydra.utils import instantiate +from lightning.pytorch import Trainer from omegaconf import OmegaConf from torch.nn.modules import Module @@ -57,25 +58,15 @@ def plot(*args: Any, **kwargs: Any) -> None: class TestObjectDetectionTask: - @pytest.mark.parametrize( - "name,classname", [("nasa_marine_debris", NASAMarineDebrisDataModule)] - ) + @pytest.mark.parametrize("name", ["nasa_marine_debris"]) @pytest.mark.parametrize("model_name", ["faster-rcnn", "fcos", "retinanet"]) def test_trainer( - self, - monkeypatch: MonkeyPatch, - name: str, - classname: type[LightningDataModule], - model_name: str, - fast_dev_run: bool, + self, monkeypatch: MonkeyPatch, name: str, model_name: str, fast_dev_run: bool ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", f"{name}.yaml")) - conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(dict[Any, dict[Any, Any]], conf_dict) # Instantiate datamodule - datamodule_kwargs = conf_dict["datamodule"] - datamodule = classname(**datamodule_kwargs) + datamodule = instantiate(conf.datamodule) # Instantiate model monkeypatch.setattr( @@ -87,9 +78,8 @@ def test_trainer( monkeypatch.setattr( torchvision.models.detection, "RetinaNet", ObjectDetectionTestModel ) - model_kwargs = conf_dict["module"] - model_kwargs["model"] = model_name - model = ObjectDetectionTask(**model_kwargs) + conf.module.model = model_name + model = instantiate(conf.module) # Instantiate trainer trainer = Trainer( diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index f4210a7cfa9..cabfe5a2cdc 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -3,7 +3,7 @@ import os from pathlib import Path -from typing import Any, cast +from 
typing import Any import pytest import timm @@ -11,17 +11,12 @@ import torchvision from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from lightning.pytorch import LightningDataModule, Trainer +from hydra.utils import instantiate +from lightning.pytorch import Trainer from omegaconf import OmegaConf from torchvision.models._api import WeightsEnum -from torchgeo.datamodules import ( - COWCCountingDataModule, - MisconfigurationException, - SKIPPDDataModule, - SustainBenchCropYieldDataModule, - TropicalCycloneDataModule, -) +from torchgeo.datamodules import MisconfigurationException, TropicalCycloneDataModule from torchgeo.datasets import TropicalCyclone from torchgeo.models import get_model_weights, list_models from torchgeo.trainers import RegressionTask @@ -50,32 +45,19 @@ def plot(*args: Any, **kwargs: Any) -> None: class TestRegressionTask: @pytest.mark.parametrize( - "name,classname", - [ - ("cowc_counting", COWCCountingDataModule), - ("cyclone", TropicalCycloneDataModule), - ("sustainbench_crop_yield", SustainBenchCropYieldDataModule), - ("skippd", SKIPPDDataModule), - ], + "name", ["cowc_counting", "cyclone", "sustainbench_crop_yield", "skippd"] ) - def test_trainer( - self, name: str, classname: type[LightningDataModule], fast_dev_run: bool - ) -> None: + def test_trainer(self, name: str, fast_dev_run: bool) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) - conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(dict[str, dict[str, Any]], conf_dict) # Instantiate datamodule - datamodule_kwargs = conf_dict["datamodule"] - datamodule = classname(**datamodule_kwargs) + datamodule = instantiate(conf.datamodule) # Instantiate model - model_kwargs = conf_dict["module"] - model = RegressionTask(**model_kwargs) + model = instantiate(conf.module) model.model = RegressionTestModel( - in_chans=model_kwargs["in_channels"], - num_classes=model_kwargs["num_outputs"], + 
in_chans=conf.module.in_channels, num_classes=conf.module.num_outputs ) # Instantiate trainer diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 80a3404e68a..eb34204b6e3 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -9,35 +9,18 @@ import torch import torch.nn as nn from _pytest.monkeypatch import MonkeyPatch -from lightning.pytorch import LightningDataModule, Trainer +from hydra.utils import instantiate +from lightning.pytorch import Trainer from omegaconf import OmegaConf from torch.nn.modules import Module -from torchgeo.datamodules import ( - ChesapeakeCVPRDataModule, - DeepGlobeLandCoverDataModule, - ETCI2021DataModule, - GID15DataModule, - InriaAerialImageLabelingDataModule, - L7IrishDataModule, - L8BiomeDataModule, - LandCoverAIDataModule, - LoveDADataModule, - MisconfigurationException, - NAIPChesapeakeDataModule, - Potsdam2DDataModule, - SEN12MSDataModule, - SpaceNet1DataModule, - Vaihingen2DDataModule, -) +from torchgeo.datamodules import MisconfigurationException, SEN12MSDataModule from torchgeo.datasets import LandCoverAI from torchgeo.trainers import SemanticSegmentationTask class SegmentationTestModel(Module): - def __init__( - self, in_channels: int = 3, classes: int = 1000, **kwargs: Any - ) -> None: + def __init__(self, in_channels: int = 3, classes: int = 3, **kwargs: Any) -> None: super().__init__() self.conv1 = nn.Conv2d( in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0 @@ -57,34 +40,30 @@ def plot(*args: Any, **kwargs: Any) -> None: class TestSemanticSegmentationTask: @pytest.mark.parametrize( - "name,classname", + "name", [ - ("chesapeake_cvpr_5", ChesapeakeCVPRDataModule), - ("chesapeake_cvpr_7", ChesapeakeCVPRDataModule), - ("deepglobelandcover", DeepGlobeLandCoverDataModule), - ("etci2021", ETCI2021DataModule), - ("gid15", GID15DataModule), - ("inria", InriaAerialImageLabelingDataModule), - ("l7irish", L7IrishDataModule), - 
("l8biome", L8BiomeDataModule), - ("landcoverai", LandCoverAIDataModule), - ("loveda", LoveDADataModule), - ("naipchesapeake", NAIPChesapeakeDataModule), - ("potsdam2d", Potsdam2DDataModule), - ("sen12ms_all", SEN12MSDataModule), - ("sen12ms_s1", SEN12MSDataModule), - ("sen12ms_s2_all", SEN12MSDataModule), - ("sen12ms_s2_reduced", SEN12MSDataModule), - ("spacenet1", SpaceNet1DataModule), - ("vaihingen2d", Vaihingen2DDataModule), + "chesapeake_cvpr_5", + "chesapeake_cvpr_7", + "deepglobelandcover", + "etci2021", + "gid15", + "inria", + "l7irish", + "l8biome", + "landcoverai", + "loveda", + "naipchesapeake", + "potsdam2d", + "sen12ms_all", + "sen12ms_s1", + "sen12ms_s2_all", + "sen12ms_s2_reduced", + "spacenet1", + "vaihingen2d", ], ) def test_trainer( - self, - monkeypatch: MonkeyPatch, - name: str, - classname: type[LightningDataModule], - fast_dev_run: bool, + self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: if name == "naipchesapeake": pytest.importorskip("zipfile_deflate64") @@ -94,18 +73,14 @@ def test_trainer( monkeypatch.setattr(LandCoverAI, "sha256", sha256) conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) - conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(dict[Any, dict[Any, Any]], conf_dict) # Instantiate datamodule - datamodule_kwargs = conf_dict["datamodule"] - datamodule = classname(**datamodule_kwargs) + datamodule = instantiate(conf.datamodule) # Instantiate model monkeypatch.setattr(smp, "Unet", create_model) monkeypatch.setattr(smp, "DeepLabV3Plus", create_model) - model_kwargs = conf_dict["module"] - model = SemanticSegmentationTask(**model_kwargs) + model = instantiate(conf.module) # Instantiate trainer trainer = Trainer( diff --git a/train.py b/train.py index de2e6f85691..0722a3d7b96 100755 --- a/train.py +++ b/train.py @@ -6,70 +6,17 @@ """torchgeo model training script.""" import os -from typing import Any, cast +from typing import cast import lightning.pytorch as pl +from 
hydra.utils import instantiate from lightning.pytorch import LightningDataModule, LightningModule, Trainer from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger from omegaconf import DictConfig, OmegaConf -from torchgeo.datamodules import ( - BigEarthNetDataModule, - ChesapeakeCVPRDataModule, - COWCCountingDataModule, - DeepGlobeLandCoverDataModule, - ETCI2021DataModule, - EuroSATDataModule, - GID15DataModule, - InriaAerialImageLabelingDataModule, - LandCoverAIDataModule, - LoveDADataModule, - NAIPChesapeakeDataModule, - NASAMarineDebrisDataModule, - Potsdam2DDataModule, - RESISC45DataModule, - SEN12MSDataModule, - So2SatDataModule, - SpaceNet1DataModule, - TropicalCycloneDataModule, - UCMercedDataModule, - Vaihingen2DDataModule, -) -from torchgeo.trainers import ( - BYOLTask, - ClassificationTask, - MultiLabelClassificationTask, - ObjectDetectionTask, - RegressionTask, - SemanticSegmentationTask, -) - -TASK_TO_MODULES_MAPPING: dict[ - str, tuple[type[LightningModule], type[LightningDataModule]] -] = { - "bigearthnet": (MultiLabelClassificationTask, BigEarthNetDataModule), - "byol": (BYOLTask, ChesapeakeCVPRDataModule), - "chesapeake_cvpr": (SemanticSegmentationTask, ChesapeakeCVPRDataModule), - "cowc_counting": (RegressionTask, COWCCountingDataModule), - "cyclone": (RegressionTask, TropicalCycloneDataModule), - "deepglobelandcover": (SemanticSegmentationTask, DeepGlobeLandCoverDataModule), - "eurosat": (ClassificationTask, EuroSATDataModule), - "etci2021": (SemanticSegmentationTask, ETCI2021DataModule), - "gid15": (SemanticSegmentationTask, GID15DataModule), - "inria": (SemanticSegmentationTask, InriaAerialImageLabelingDataModule), - "landcoverai": (SemanticSegmentationTask, LandCoverAIDataModule), - "loveda": (SemanticSegmentationTask, LoveDADataModule), - "naipchesapeake": (SemanticSegmentationTask, NAIPChesapeakeDataModule), - "nasa_marine_debris": (ObjectDetectionTask, 
NASAMarineDebrisDataModule), - "potsdam2d": (SemanticSegmentationTask, Potsdam2DDataModule), - "resisc45": (ClassificationTask, RESISC45DataModule), - "sen12ms": (SemanticSegmentationTask, SEN12MSDataModule), - "so2sat": (ClassificationTask, So2SatDataModule), - "spacenet1": (SemanticSegmentationTask, SpaceNet1DataModule), - "ucmerced": (ClassificationTask, UCMercedDataModule), - "vaihingen2d": (SemanticSegmentationTask, Vaihingen2DDataModule), -} +from torchgeo.datamodules import MisconfigurationException +from torchgeo.trainers import BYOLTask, ObjectDetectionTask def set_up_omegaconf() -> DictConfig: @@ -91,7 +38,6 @@ def set_up_omegaconf() -> DictConfig: Raises: FileNotFoundError: when ``config_file`` does not exist - ValueError: when ``task.name`` is not a valid task """ conf = OmegaConf.load("conf/defaults.yaml") command_line_conf = OmegaConf.from_cli() @@ -107,34 +53,15 @@ def set_up_omegaconf() -> DictConfig: conf = OmegaConf.merge( # Merge in any arguments passed via the command line conf, command_line_conf ) - - # These OmegaConf structured configs enforce a schema at runtime, see: - # https://omegaconf.readthedocs.io/en/2.0_branch/structured_config.html#merging-with-other-configs - task_name = conf.experiment.task - task_config_fn = os.path.join("conf", f"{task_name}.yaml") - if task_name == "test": - task_conf = OmegaConf.create() - elif os.path.exists(task_config_fn): - task_conf = cast(DictConfig, OmegaConf.load(task_config_fn)) - else: - raise ValueError( - f"experiment.task={task_name} is not recognized as a valid task" - ) - - conf = OmegaConf.merge(task_conf, conf) conf = cast(DictConfig, conf) # convince mypy that everything is alright - return conf def main(conf: DictConfig) -> None: """Main training loop.""" - ###################################### - # Setup output directory - ###################################### - - experiment_name = conf.experiment.name - task_name = conf.experiment.task + experiment_name = ( + 
f"{conf.datamodule._target_.lower()}_{conf.module._target_.lower()}" + ) if os.path.isfile(conf.program.output_dir): raise NotADirectoryError("`program.output_dir` must be a directory") os.makedirs(conf.program.output_dir, exist_ok=True) @@ -154,45 +81,30 @@ def main(conf: DictConfig) -> None: + "empty. We don't want to overwrite any existing results, exiting..." ) - with open(os.path.join(experiment_dir, "experiment_config.yaml"), "w") as f: + with open(os.path.join(experiment_dir, "config.yaml"), "w") as f: OmegaConf.save(config=conf, f=f) - ###################################### - # Choose task to run based on arguments or configuration - ###################################### - # Convert the DictConfig into a dictionary so that we can pass as kwargs. - task_args = cast(dict[str, Any], OmegaConf.to_object(conf.experiment.module)) - datamodule_args = cast( - dict[str, Any], OmegaConf.to_object(conf.experiment.datamodule) - ) + # Define module and datamodule + datamodule: LightningDataModule = instantiate(conf.datamodule) + task: LightningModule = instantiate(conf.module) - datamodule: LightningDataModule - task: LightningModule - if task_name in TASK_TO_MODULES_MAPPING: - task_class, datamodule_class = TASK_TO_MODULES_MAPPING[task_name] - task = task_class(**task_args) - datamodule = datamodule_class(**datamodule_args) - else: - raise ValueError( - f"experiment.task={task_name} is not recognized as a valid task" - ) - - ###################################### - # Setup trainer - ###################################### + # Define callbacks tb_logger = TensorBoardLogger(conf.program.log_dir, name=experiment_name) csv_logger = CSVLogger(conf.program.log_dir, name=experiment_name) if isinstance(task, ObjectDetectionTask): monitor_metric = "val_map" mode = "max" + elif isinstance(task, BYOLTask): + monitor_metric = "train_loss" + mode = "min" else: monitor_metric = "val_loss" mode = "min" checkpoint_callback = ModelCheckpoint( monitor=monitor_metric, - 
filename="checkpoint-epoch{epoch:02d}-val_loss{val_loss:.2f}", + filename=f"checkpoint-{{epoch:02d}}-{{{monitor_metric}:.2f}}", dirpath=experiment_dir, save_top_k=1, save_last=True, @@ -202,18 +114,22 @@ def main(conf: DictConfig) -> None: monitor=monitor_metric, min_delta=0.00, patience=18, mode=mode ) - trainer_args = cast(dict[str, Any], OmegaConf.to_object(conf.trainer)) - - trainer_args["callbacks"] = [checkpoint_callback, early_stopping_callback] - trainer_args["logger"] = [tb_logger, csv_logger] - trainer_args["default_root_dir"] = experiment_dir - trainer = Trainer(**trainer_args) + # Define trainer + trainer: Trainer = instantiate( + conf.trainer, + callbacks=[checkpoint_callback, early_stopping_callback], + logger=[tb_logger, csv_logger], + default_root_dir=experiment_dir, + ) - ###################################### - # Run experiment - ###################################### + # Train trainer.fit(model=task, datamodule=datamodule) - trainer.test(ckpt_path="best", datamodule=datamodule) + + # Test + try: + trainer.test(ckpt_path="best", datamodule=datamodule) + except MisconfigurationException: + pass if __name__ == "__main__":