Skip to content

Commit

Permalink
Refactor train.py (#1237)
Browse files Browse the repository at this point in the history
* add pretrain.py tested with seco100k

* refactor

* add pretrain.py tested with seco100k

* refactor

* refactor to train.py and add simclr

* revert simclr and pretrain.py changes

* revert more simclr changesg

* add refactor to configs and train.py

* add hydra.utils.instantiate import

* fix flake8

* update tests and add hydra-core to deps

* fix byol tests

* update exp name

* format

* remove evaluate.py

* add hydra-core to min tests deps

* update tests

* add trainer to configs and use lightning.Trainer insteead of pl.Trainer

* add eurosat100 test

* update train.py

* lightning.Trainer -> lightning.pytorch.Trainer

* remove omegaconf

* update hydra-core to 1.1.1 and fix mypy

* add recursive hydra config test

* update coment

* update test config

* fix tests

* add omegaconf back in

* remove hydra recursive test

* update hydra-core to 2.3.0 for ci

* Try older hydra

* Older hydra requires old omegaconf

* Try older hydra

* Try newer hydra

* Try older hydra

* Try newer hydra

* Finalize minimum dep versions

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
  • Loading branch information
isaaccorley and adamjstewart committed Apr 23, 2023
1 parent 98d8666 commit d3c82a5
Show file tree
Hide file tree
Showing 73 changed files with 1,292 additions and 1,643 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ repos:
hooks:
- id: mypy
args: [--strict, --ignore-missing-imports, --show-error-codes]
additional_dependencies: [torch>=2, torchmetrics>=0.10, lightning>=1.8, pytest>=6, pyvista>=0.20, omegaconf>=2.1, kornia>=0.6, numpy>=1.22.0]
additional_dependencies: [torch>=2, torchmetrics>=0.10, lightning>=1.8, pytest>=6, pyvista>=0.20, omegaconf>=2.0.1, hydra-core>=1, kornia>=0.6, numpy>=1.22]
exclude: (build|data|dist|logo|logs|output)/
40 changes: 21 additions & 19 deletions conf/bigearthnet.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
module:
_target_: torchgeo.trainers.MultiLabelClassificationTask
loss: "bce"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: null
in_channels: 14
num_classes: 19

datamodule:
_target_: torchgeo.datamodules.BigEarthNetDataModule
root: "data/bigearthnet"
bands: "all"
num_classes: ${module.num_classes}
batch_size: 128
num_workers: 4

trainer:
_target_: lightning.pytorch.Trainer
accelerator: gpu
devices: 1
min_epochs: 10
max_epochs: 40
benchmark: True
experiment:
task: "bigearthnet"
module:
loss: "bce"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: null
in_channels: 14
num_classes: 19
datamodule:
root: "data/bigearthnet"
bands: "all"
num_classes: ${experiment.module.num_classes}
batch_size: 128
num_workers: 4
min_epochs: 15
max_epochs: 40
26 changes: 0 additions & 26 deletions conf/byol.yaml

This file was deleted.

61 changes: 31 additions & 30 deletions conf/chesapeake_cvpr.yaml
Original file line number Diff line number Diff line change
@@ -1,33 +1,34 @@
module:
_target_: torchgeo.trainers.SemanticSegmentationTask
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 4
num_classes: 7
num_filters: 256
ignore_index: null

datamodule:
_target_: torchgeo.datamodules.ChesapeakeCVPRDataModule
root: "data/chesapeake/cvpr"
train_splits:
- "de-train"
val_splits:
- "de-val"
test_splits:
- "de-test"
batch_size: 200
patch_size: 256
num_workers: 4
class_set: ${module.num_classes}
use_prior_labels: False

trainer:
_target_: lightning.pytorch.Trainer
accelerator: gpu
devices: 1
min_epochs: 20
max_epochs: 100
benchmark: True
experiment:
task: "chesapeake_cvpr"
name: "chesapeake_cvpr_example"
module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 4
num_classes: 7
num_filters: 256
ignore_index: null
datamodule:
root: "data/chesapeake/cvpr"
train_splits:
- "de-train"
val_splits:
- "de-val"
test_splits:
- "de-test"
batch_size: 200
patch_size: 256
num_workers: 4
class_set: ${experiment.module.num_classes}
use_prior_labels: False
min_epochs: 15
max_epochs: 40
31 changes: 17 additions & 14 deletions conf/cowc_counting.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
module:
_target_: torchgeo.trainers.RegressionTask
model: resnet18
weights: null
num_outputs: 1
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2

datamodule:
_target_: torchgeo.datamodules.COWCCountingDataModule
root: "data/cowc_counting"
batch_size: 64
num_workers: 4

trainer:
_target_: lightning.pytorch.Trainer
accelerator: gpu
devices: 1
min_epochs: 15
experiment:
task: cowc_counting
name: cowc_counting_test
module:
model: resnet18
weights: null
num_outputs: 1
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
datamodule:
root: "data/cowc_counting"
batch_size: 64
num_workers: 4
max_epochs: 40
31 changes: 17 additions & 14 deletions conf/cyclone.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
module:
_target_: torchgeo.trainers.RegressionTask
model: "resnet18"
weights: null
num_outputs: 1
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2

datamodule:
_target_: torchgeo.datamodules.TropicalCycloneDataModule
root: "data/cyclone"
batch_size: 32
num_workers: 4

trainer:
_target_: lightning.pytorch.Trainer
accelerator: gpu
devices: 1
min_epochs: 15
experiment:
task: "cyclone"
name: "cyclone_test"
module:
model: "resnet18"
weights: null
num_outputs: 1
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
datamodule:
root: "data/cyclone"
batch_size: 32
num_workers: 4
max_epochs: 40
48 changes: 28 additions & 20 deletions conf/deepglobelandcover.yaml
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
experiment:
task: "deepglobelandcover"
module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
in_channels: 3
num_classes: 7
num_filters: 1
ignore_index: null
datamodule:
root: "data/deepglobelandcover"
batch_size: 1
patch_size: 64
val_split_pct: 0.5
num_workers: 0
module:
_target_: torchgeo.trainers.SemanticSegmentationTask
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
in_channels: 3
num_classes: 7
num_filters: 1
ignore_index: null

datamodule:
_target_: torchgeo.datamodules.DeepGlobeLandCoverDataModule
root: "data/deepglobelandcover"
batch_size: 1
patch_size: 64
val_split_pct: 0.5
num_workers: 0

trainer:
_target_: lightning.pytorch.Trainer
accelerator: gpu
devices: 1
min_epochs: 15
max_epochs: 40
2 changes: 1 addition & 1 deletion conf/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ program: # These are the arguments that define how the train.py script works
output_dir: output
data_dir: data
log_dir: logs
overwrite: False
overwrite: False
40 changes: 24 additions & 16 deletions conf/etci2021.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
experiment:
task: "etci2021"
module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 6
num_classes: 2
ignore_index: 0
datamodule:
root: "data/etci2021"
batch_size: 32
num_workers: 4
module:
_target_: torchgeo.trainers.SemanticSegmentationTask
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: "imagenet"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 6
num_classes: 2
ignore_index: 0

datamodule:
_target_: torchgeo.datamodules.ETCI2021DataModule
root: "data/etci2021"
batch_size: 32
num_workers: 4

trainer:
_target_: lightning.pytorch.Trainer
accelerator: gpu
devices: 1
min_epochs: 15
max_epochs: 40
36 changes: 22 additions & 14 deletions conf/eurosat.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
experiment:
task: "eurosat"
module:
loss: "ce"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: null
in_channels: 13
num_classes: 10
datamodule:
root: "data/eurosat"
batch_size: 128
num_workers: 4
module:
_target_: torchgeo.trainers.ClassificationTask
loss: "ce"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: null
in_channels: 13
num_classes: 10

datamodule:
_target_: torchgeo.datamodules.EuroSATDataModule
root: "data/eurosat"
batch_size: 128
num_workers: 4

trainer:
_target_: lightning.pytorch.Trainer
accelerator: gpu
devices: 1
min_epochs: 15
max_epochs: 40

0 comments on commit d3c82a5

Please sign in to comment.