Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor train.py #1237

Merged
merged 42 commits into from
Apr 23, 2023
Merged

Refactor train.py #1237

merged 42 commits into from
Apr 23, 2023

Conversation

isaaccorley
Copy link
Collaborator

@isaaccorley isaaccorley commented Apr 12, 2023

This PR refactors train.py to use hydra.utils.instantiate to define the trainer and datamodule inside the config file without having to do any config file magic. This allows a user to define a config like below:

trainer:
  _target_: pytorch_lightning.Trainer
  accelerator: gpu
  devices: 1
  min_epochs: 15
  max_epochs: 40

module:
  _target_: torchgeo.trainers.ClassificationTask
  loss: "ce"
  model: "resnet18"
  learning_rate: 1e-3
  learning_rate_schedule_patience: 6
  weights: null
  in_channels: 3
  num_classes: 45

datamodule:
  _target_: torchgeo.datamodules.RESISC45DataModule
  root: "data/resisc45"
  batch_size: 128
  num_workers: 4

@isaaccorley isaaccorley self-assigned this Apr 12, 2023
@isaaccorley isaaccorley requested review from nilsleh, calebrob6 and adamjstewart and removed request for calebrob6 April 12, 2023 21:39
@github-actions github-actions bot added the scripts Training and evaluation scripts label Apr 12, 2023
@adamjstewart adamjstewart added this to In progress in SSL4EO-L via automation Apr 12, 2023
@adamjstewart
Copy link
Collaborator

We talked about this in the meeting today, but we're probably going to try to combine pretrain/train/evaluate/predict into a single main.py that can handle all of the above, and store things like datamodule and trainer classes in the YAML config file.

@adamjstewart
Copy link
Collaborator

Actually, we could name it torchgeo/__main__.py so that users can run python -m torchgeo --train --datamodule=... --task=.... This might be sufficient to close #228.

@github-actions github-actions bot added the trainers PyTorch Lightning trainers label Apr 15, 2023
torchgeo/trainers/simclr.py Outdated Show resolved Hide resolved
train.py Outdated Show resolved Hide resolved
@github-actions github-actions bot removed the trainers PyTorch Lightning trainers label Apr 16, 2023
@isaaccorley isaaccorley changed the title SSL pretraining Refactor train.py Apr 17, 2023
@calebrob6
Copy link
Member

I guess we can get rid of evaluate now that we have testing in train.py

tests/conf/trainer.yaml Outdated Show resolved Hide resolved
Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like our existing code, there are a ton of discrepancies as to whether we should call a LightningModule a module, model, task, or trainer. I won't comment on those since we never came to a conclusion on #996.

Config files look amazing! Now just need to decide how closely we can align our training script to LightningCLI.

conf/bigearthnet.yaml Outdated Show resolved Hide resolved
tests/test_trainer.py Outdated Show resolved Hide resolved
requirements/tests.txt Outdated Show resolved Hide resolved
tests/trainers/test_byol.py Outdated Show resolved Hide resolved
train.py Show resolved Hide resolved
setup.cfg Outdated Show resolved Hide resolved
@adamjstewart
Copy link
Collaborator

I think everything looks good to me. Just want to further minimize the hydra version to the lowest version that gets our tests to pass.

@calebrob6
Copy link
Member

calebrob6 commented Apr 23, 2023

Pausing a bit to think here. Just so I understand correctly -- the reason we're bringing in hydra as a dependency is for instantiate? And this is to 1.) solve the problem of needing to maintain an explicit list of "name" --> (class, task) in train.py which 2.) makes train.py a lot more flexible?

Expanding a bit, this just turns each yaml config file into a piece of code that does:

task = SomeTask(**kwargs)
datamodule = SomeDataModule(**kwargs)
trainer = pl.Trainer(
   some_sensibly_hardcoded_stuff_like_tensorboard,
   an_easy_way_to_remember_how_to_use_gpus,
   **kwargs

which then get used in trainer.fit(...) and trainer.test(...)

Note: if so, I'm okay with this (I'm a fan of https://github.com/ashleve/lightning-hydra-template), but this just seems like a good place to think about what we're doing here.

@adamjstewart
Copy link
Collaborator

Correct. This also lets users create their own datamodule and/or trainer and use it with train.py without having to edit train.py. We're planning on moving train.py to torchgeo/__main__.py so that it gets installed and users can use python -m torchgeo fit config_file=.... This will become even more important when we do that.

@calebrob6
Copy link
Member

Okay cool! I don't see why train.py needs to be moved to torchgeo/__main__.py actually, but I think that's a discussion for another day.

@calebrob6 calebrob6 merged commit d3c82a5 into microsoft:main Apr 23, 2023
SSL4EO-L automation moved this from In progress to Done Apr 23, 2023
@adamjstewart
Copy link
Collaborator

It relates to #228, just provides an easy way for people to pip install torchgeo and user our train.py script.

@adriantre
Copy link
Contributor

adriantre commented Jun 8, 2023

Just to mention, Pytorch Lightning CLI is possibly solving the same thing.

python main.py fit --config /config.yaml 
                   --model MyModel 
                   --data MyGeoDataModule 
                   --any_class_param_to_override

where fit can be replaced with test, predict etc.

@adamjstewart
Copy link
Collaborator

Yep, I'm planning on integrating this and adding a torchgeo command in the next release.

@adamjstewart adamjstewart added this to the 0.5.0 milestone Sep 29, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
dependencies Packaging and dependencies scripts Training and evaluation scripts testing Continuous integration testing
Projects
No open projects
Development

Successfully merging this pull request may close these issues.

None yet

4 participants