Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add command line script #1559

Merged
merged 26 commits into from
Sep 27, 2023
Merged

Conversation

adamjstewart
Copy link
Collaborator

@adamjstewart adamjstewart commented Sep 12, 2023

This PR replaces our train.py script with a torchgeo script installed by pip. It can be invoked in two ways:

# If torchgeo has been installed
torchgeo
# If torchgeo has been installed, or if it has been cloned to the current directory
python3 -m torchgeo

It is based on LightningCLI, and supports command-line configuration or YAML/JSON config files. Valid options can be found from the help messages:

# See valid stages
torchgeo --help
# See valid trainer options
torchgeo fit --help
# See valid model options
torchgeo fit --model.help ClassificationTask
# See valid data options
torchgeo fit --data.help EuroSAT100DataModule

An example of the script in action:

# Train and validate a model
torchgeo fit --config tests/conf/eurosat100.yaml --trainer.max_epochs=1
# Validate-only
torchgeo validate --config tests/conf/eurosat100.yaml
# Calculate and report test accuracy
torchgeo test --config tests/conf/eurosat100.yaml

It can also be imported and used in a Python script if you need to extend it to add new features:

from torchgeo.main import main

main(["fit", "--config", "tests/conf/eurosat100.yaml"])

Documentation

Setuptools packaging:

LightningCLI:

Closes #228
Closes #1352

@adamjstewart adamjstewart added this to the 0.5.0 milestone Sep 12, 2023
@github-actions github-actions bot added testing Continuous integration testing dependencies Packaging and dependencies scripts Training and evaluation scripts trainers PyTorch Lightning trainers labels Sep 12, 2023
@github-actions github-actions bot added datasets Geospatial or benchmark datasets datamodules PyTorch Lightning datamodules labels Sep 14, 2023
@calebrob6
Copy link
Member

Why do we want this vs. a train.py script?

@adamjstewart
Copy link
Collaborator Author

I think there are two ways to interpret your question:

Why bother with an entry point, what's wrong with train.py?

It's important to distinguish between TorchGeo (the Python library) and TorchGeo (the GitHub repo). For developers like you and me, we primarily install TorchGeo (the GitHub repo) using:

> git clone https://github.com/microsoft/torchgeo.git

This also installs train.py, unit tests, some conf files, and experiment code. However, the majority of users will instead install TorchGeo (the Python library) using:

> pip install torchgeo  # or conda/spack

The primary advantage of using an entry point is that now train.py gets installed for all users, not just developers. Other advantages:

  • It's easier to unit test and collect coverage reports for
  • It's an official part of our stable API and can be documented in our API docs
  • It's easier to use and document in tutorials
  • It can be extended in custom scripts (from torchgeo.main import main; ...)

What's the advantage of LightningCLI over our current custom omegaconf/hydra code?

There are a ton of advantages of LightningCLI over our current implementation:

  • Runtime verification: all init_args undergo runtime type verification! Use dict_kwargs for things that can't be verified, like dataset kwargs. See Add config file schema #1352
  • Self-documenting: use torchgeo fit --help for all valid trainer options, torchgeo fit --model.help ClassificationTask for model options, and torchgeo fit --data.help EuroSAT100DataModule for data module options
  • Less code: automatically handles merging multiple config files with CLI args
  • Override optimizers: can override the default optimizer or lr scheduler from config/CLI for experimentation
  • JSON support: probably most interesting for web app integration

LightningCLI has a ton of exciting features. I would encourage everyone to read through all of https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html before deciding whether or not it is useful.

@calebrob6
Copy link
Member

calebrob6 commented Sep 22, 2023

Got it, I don't have a problem with this in principle. This is just a huge PR and I haven't wrapped my head around what you get from LightningCLI.

@adamjstewart
Copy link
Collaborator Author

If you ignore the config files and tests it's actually pretty small. But yeah, definitely takes a bit to wrap your head around. But I think this will put us more in line with other Lightning projects. The config file format will be the same for TorchGeo and every other project that uses LightningCLI.

@adamjstewart adamjstewart marked this pull request as ready for review September 23, 2023 19:05
Copy link
Member

@calebrob6 calebrob6 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just played around with this and I really like it, e.g. Lightning parallelized over 4 GPUs by default, and I trained RESISC45 for 10 epochs in like 2 minutes.

Some things that aren't clear to me:

  • How do I get it to save model checkpoints somewhere?
  • How can I name the experiment run something that makes sense?
  • When running torchgeo validate --config conf/something.yaml, how do I specify that it should validate with a model checkpoint? Currently it just uses the default model. It also creates a new tensorboard log entry.

Some things that I noticed:

  • Train classification accuracy and the other metrics aren't being logged during training, validation accuracy isn't being computed during validation, etc.

I really want to add a docs page to go along with this to help get people started. Maybe we should implement a new requirement on big PRs that you have to write a little mini tutorial to go along with it?

@adamjstewart
Copy link
Collaborator Author

adamjstewart commented Sep 25, 2023

Preliminary answers before I start fixing things:

How do I get it to save model checkpoints somewhere?

You would override --trainer.default_root_dir. It defaults to the current directory.

How can I name the experiment run something that makes sense?

Still need to add the ability to customize that. It defaults to lightning_logs. It's not hard, just want to make sure I'm not missing something builtin.

how do I specify that it should validate with a model checkpoint?

You would use the ckpt_path or model.init_args.weights parameter in config or on the CLI.

Train classification accuracy and the other metrics aren't being logged during training, validation accuracy isn't being computed during validation, etc.

Will fix.

I really want to add a docs page to go along with this to help get people started. Maybe we should implement a new requirement on big PRs that you have to write a little mini tutorial to go along with it?

I've been really lazy with this because I keep telling myself to put it off until we start writing tons of tutorials. This one will be particularly hard because it's all command-line commands. But I think it will still work.

@adamjstewart
Copy link
Collaborator Author

adamjstewart commented Sep 25, 2023

Before I accidentally delete this file, I did do a survey of how other common CLI tools organize their code:

Expand
# black

# pyproject.toml
[project.scripts]
black = "black:patched_main"
blackd = "blackd:patched_main [d]"

# src/black/__main__.py
from black import patched_main

patched_main()

# src/black/__init__.py
def patched_main() -> None:
    # PyInstaller patches multiprocessing to need freeze_support() even in non-Windows
    # environments so just assume we always need to call it if frozen.
    if getattr(sys, "frozen", False):
        from multiprocessing import freeze_support

        freeze_support()

    patch_click()
    main()


if __name__ == "__main__":
    patched_main()

# flake8

# setup.cfg
[options.entry_points]
console_scripts =
        flake8 = flake8.main.cli:main

# src/flake8/__main__.py
from flake8.main.cli import main

if __name__ == "__main__":
    raise SystemExit(main())

# src/flake8/main/cli.py
from flake8.main import application


def main(argv: Sequence[str] | None = None) -> int:
    if argv is None:
        argv = sys.argv[1:]

    app = application.Application()
    app.run(argv)
    return app.exit_code()

# isort

# pyproject.toml
[tool.poetry.scripts]
isort = "isort.main:main"
isort-identify-imports = "isort.main:identify_imports_main"

# isort/__main__.py
from isort.main import main

main()

# isort/main.py
def main():
    pass

if __name__ == "__main__":
    main()

# pytest

# setup.cfg
[options.entry_points]
console_scripts =
        pytest=pytest:console_main
        py.test=pytest:console_main

# src/pytest/__main__.py
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.console_main())

# src/_pytest/config/__init__.py
def main():
    ...

def console_main():
    try:
        main()
    except:
        ...

# pydocstyle

# pyproject.toml
[tool.poetry.scripts]
pydocstyle = "pydocstyle.cli:main"

# src/pydocstyle/__main__.py
def main() -> None:
    from pydocstyle import cli

    cli.main()

if __name__ == '__main__':
    main()

# src/pydocstyle/cli.py
def main():
    """Run pydocstyle as a script."""
    try:
        sys.exit(run_pydocstyle())
    except KeyboardInterrupt:
        pass

# pyupgrade

# setup.cfg
[options.entry_points]
console_scripts =
        pyupgrade = pyupgrade._main:main

# pyupgrade/__main__.py
from pyupgrade._main import main

if __name__ == '__main__':
    raise SystemExit(main())

# pyupgrade/_main.py
def main():
    ...

if __name__ == '__main__':
    raise SystemExit(main())

# mypy

# setup.py
    entry_points={
        "console_scripts": [
            "mypy=mypy.__main__:console_entry",
            "stubgen=mypy.stubgen:main",
            "stubtest=mypy.stubtest:main",
            "dmypy=mypy.dmypy.client:console_entry",
            "mypyc=mypyc.__main__:main",
        ]
    },

# mypy/__main__.py
def console_entry() -> None:
    try:
        main()
        sys.stdout.flush()
        sys.stderr.flush()
    except BrokenPipeError:
        # Python flushes standard streams on exit; redirect remaining output
        # to devnull to avoid another BrokenPipeError at shutdown
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
        sys.exit(2)
    except KeyboardInterrupt:
        _, options = process_options(args=sys.argv[1:])
        if options.show_traceback:
            sys.stdout.write(traceback.format_exc())
        formatter = FancyFormatter(sys.stdout, sys.stderr, False)
        msg = "Interrupted\n"
        sys.stdout.write(formatter.style(msg, color="red", bold=True))
        sys.stdout.flush()
        sys.stderr.flush()
        sys.exit(2)


if __name__ == "__main__":
    console_entry()

# pip

# setup.py
    entry_points={
        "console_scripts": [
            "pip=pip._internal.cli.main:main",
            "pip{}=pip._internal.cli.main:main".format(sys.version_info[0]),
            "pip{}.{}=pip._internal.cli.main:main".format(*sys.version_info[:2]),
        ],
    },

# src/pip/__main__.py
if __name__ == "__main__":
    from pip._internal.cli.main import main as _main

    sys.exit(_main())

# src/pip/_internal/cli/main.py
def main():
    ...

# build

# pyproject.toml
[project.scripts]
pyproject-build = "build.__main__:entrypoint"

# src/build/__main__.py
def main(args):
    ...

def entrypoint() -> None:
    main(sys.argv[1:])

if __name__ == '__main__':
    main(sys.argv[1:], 'python -m build')

We don't have to use torchgeo.main.main, that's just a design choice.

@adamjstewart
Copy link
Collaborator Author

Train classification accuracy and the other metrics aren't being logged during training, validation accuracy isn't being computed during validation, etc.

@calebrob6 can you see if this is fixed by the last commit? I can also play around with on_step/on_epoch to control frequency and prog_bar to control verbosity.

@adamjstewart
Copy link
Collaborator Author

Hoping for a quick response to Lightning-AI/pytorch-lightning#18641, otherwise I'll just ignore the mypy warnings.

@calebrob6
Copy link
Member

Yep, metrics are logged now, however they are logged at strange intervals (see x-axis of graphs below) and I get this warning in the multi-gpu case:

PossibleUserWarning: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.

image

@calebrob6
Copy link
Member

calebrob6 commented Sep 26, 2023

By default, lightning should log every 50 training steps (https://lightning.ai/docs/pytorch/stable/extensions/logging.html#logging-frequency), which is why the first train point is near 50, but it doesn't explain why the frequency of train logging increases after that.

Update, for whatever reason, when the trainer gets the early stopping signal (I see this in the logs "Trainer was signaled to stop but the required min_epochs=15 or min_steps=None has not been met. Training will continue...") the log frequency gets set to 1.

@adamjstewart
Copy link
Collaborator Author

Type errors will be fixed by Lightning-AI/pytorch-lightning#18646

@adamjstewart
Copy link
Collaborator Author

adamjstewart commented Sep 27, 2023

@calebrob6 I'm inclined to call this PR "good enough" for a first draft and get it merged so we can focus on other last minute features. We can add docs and add new features or fix bugs in future releases. What do you think?

@calebrob6
Copy link
Member

calebrob6 commented Sep 27, 2023

First thought: Are you feeling okay!?
Second thought: We do need to figure out how to save model checkpoints and pick where the outputs go. Currently running torchgeo fit --config ... will create a directory called ./lightning_logs/ that just holds the tensorboard logs. If we release this in 0.5 but then figure out we need to add something to main.py to get the checkpoints to save, then we might confuse a lot of users.

@calebrob6
Copy link
Member

(otherwise I don't care if the metric logging is working perfectly)

@adamjstewart
Copy link
Collaborator Author

We do need to figure out how to save model checkpoints

They are automatically saved in lightning_logs/version_#/checkpoints

and pick where the outputs go.

The root directory can be controlled using --trainer.default_root_dir my_logs or:

trainer:
  default_root_dir: my_logs

@calebrob6
Copy link
Member

Problem solved!

@calebrob6 calebrob6 merged commit 984e222 into microsoft:main Sep 27, 2023
37 checks passed
@adamjstewart adamjstewart deleted the scripts/lightningcli branch December 19, 2023 18:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
datamodules PyTorch Lightning datamodules dependencies Packaging and dependencies scripts Training and evaluation scripts testing Continuous integration testing trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add config file schema Add train entrypoint
2 participants