-
Notifications
You must be signed in to change notification settings - Fork 347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add command line script #1559
Add command line script #1559
Conversation
0714ab5
to
96b85b0
Compare
Why do we want this vs. a |
I think there are two ways to interpret your question: Why bother with an entry point, what's wrong with train.py?It's important to distinguish between TorchGeo (the Python library) and TorchGeo (the GitHub repo). For developers like you and me, we primarily install TorchGeo (the GitHub repo) using: > git clone https://github.com/microsoft/torchgeo.git This also installs train.py, unit tests, some conf files, and experiment code. However, the majority of users will instead install TorchGeo (the Python library) using: > pip install torchgeo # or conda/spack The primary advantage of using an entry point is that now train.py gets installed for all users, not just developers. Other advantages:
What's the advantage of LightningCLI over our current custom omegaconf/hydra code?There are a ton of advantages of LightningCLI over our current implementation:
LightningCLI has a ton of exciting features. I would encourage everyone to read through all of https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html before deciding whether or not it is useful. |
fe9b06c
to
fdeeae9
Compare
6c80d70
to
f95c7d1
Compare
Got it, I don't have a problem with this in principle. This is just a huge PR and I haven't wrapped my head around what you get from LightningCLI. |
If you ignore the config files and tests it's actually pretty small. But yeah, definitely takes a bit to wrap your head around. But I think this will put us more in line with other Lightning projects. The config file format will be the same for TorchGeo and every other project that uses LightningCLI. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just played around with this and I really like it, e.g. Lightning parallelized over 4 GPUs by default, and I trained RESISC45 for 10 epochs in like 2 minutes.
Some things that aren't clear to me:
- How do I get it to save model checkpoints somewhere?
- How can I name the experiment run something that makes sense?
- When running torchgeo validate --config conf/something.yaml, how do I specify that it should validate with a model checkpoint? Currently it just uses the default model. It also creates a new tensorboard log entry.
Some things that I noticed:
- Train classification accuracy and the other metrics aren't being logged during training, validation accuracy isn't being computed during validation, etc.
I really want to add a docs page to go along with this to help get people started. Maybe we should implement a new requirement on big PRs that you have to write a little mini tutorial to go along with it?
Preliminary answers before I start fixing things:
You would override
Still need to add the ability to customize that. It defaults to
You would use the
Will fix.
I've been really lazy with this because I keep telling myself to put it off until we start writing tons of tutorials. This one will be particularly hard because it's all command-line commands. But I think it will still work. |
Before I accidentally delete this file, I did do a survey of how other common CLI tools organize their code: Expand# black
# pyproject.toml
[project.scripts]
black = "black:patched_main"
blackd = "blackd:patched_main [d]"
# src/black/__main__.py
from black import patched_main
patched_main()
# src/black/__init__.py
def patched_main() -> None:
# PyInstaller patches multiprocessing to need freeze_support() even in non-Windows
# environments so just assume we always need to call it if frozen.
if getattr(sys, "frozen", False):
from multiprocessing import freeze_support
freeze_support()
patch_click()
main()
if __name__ == "__main__":
patched_main()
# flake8
# setup.cfg
[options.entry_points]
console_scripts =
flake8 = flake8.main.cli:main
# src/flake8/__main__.py
from flake8.main.cli import main
if __name__ == "__main__":
raise SystemExit(main())
# src/flake8/main/cli.py
from flake8.main import application
def main(argv: Sequence[str] | None = None) -> int:
if argv is None:
argv = sys.argv[1:]
app = application.Application()
app.run(argv)
return app.exit_code()
# isort
# pyproject.toml
[tool.poetry.scripts]
isort = "isort.main:main"
isort-identify-imports = "isort.main:identify_imports_main"
# isort/__main__.py
from isort.main import main
main()
# isort/main.py
def main():
pass
if __name__ == "__main__":
main()
# pytest
# setup.cfg
[options.entry_points]
console_scripts =
pytest=pytest:console_main
py.test=pytest:console_main
# src/pytest/__main__.py
import pytest
if __name__ == "__main__":
raise SystemExit(pytest.console_main())
# src/_pytest/config/__init__.py
def main():
...
def console_main():
try:
main()
except:
...
# pydocstyle
# pyproject.toml
[tool.poetry.scripts]
pydocstyle = "pydocstyle.cli:main"
# src/pydocstyle/__main__.py
def main() -> None:
from pydocstyle import cli
cli.main()
if __name__ == '__main__':
main()
# src/pydocstyle/cli.py
def main():
"""Run pydocstyle as a script."""
try:
sys.exit(run_pydocstyle())
except KeyboardInterrupt:
pass
# pyupgrade
# setup.cfg
[options.entry_points]
console_scripts =
pyupgrade = pyupgrade._main:main
# pyupgrade/__main__.py
from pyupgrade._main import main
if __name__ == '__main__':
raise SystemExit(main())
# pyupgrade/_main.py
def main():
...
if __name__ == '__main__':
raise SystemExit(main())
# mypy
# setup.py
entry_points={
"console_scripts": [
"mypy=mypy.__main__:console_entry",
"stubgen=mypy.stubgen:main",
"stubtest=mypy.stubtest:main",
"dmypy=mypy.dmypy.client:console_entry",
"mypyc=mypyc.__main__:main",
]
},
# mypy/__main__.py
def console_entry() -> None:
try:
main()
sys.stdout.flush()
sys.stderr.flush()
except BrokenPipeError:
# Python flushes standard streams on exit; redirect remaining output
# to devnull to avoid another BrokenPipeError at shutdown
devnull = os.open(os.devnull, os.O_WRONLY)
os.dup2(devnull, sys.stdout.fileno())
sys.exit(2)
except KeyboardInterrupt:
_, options = process_options(args=sys.argv[1:])
if options.show_traceback:
sys.stdout.write(traceback.format_exc())
formatter = FancyFormatter(sys.stdout, sys.stderr, False)
msg = "Interrupted\n"
sys.stdout.write(formatter.style(msg, color="red", bold=True))
sys.stdout.flush()
sys.stderr.flush()
sys.exit(2)
if __name__ == "__main__":
console_entry()
# pip
# setup.py
entry_points={
"console_scripts": [
"pip=pip._internal.cli.main:main",
"pip{}=pip._internal.cli.main:main".format(sys.version_info[0]),
"pip{}.{}=pip._internal.cli.main:main".format(*sys.version_info[:2]),
],
},
# src/pip/__main__.py
if __name__ == "__main__":
from pip._internal.cli.main import main as _main
sys.exit(_main())
# src/pip/_internal/cli/main.py
def main():
...
# build
# pyproject.toml
[project.scripts]
pyproject-build = "build.__main__:entrypoint"
# src/build/__main__.py
def main(args):
...
def entrypoint() -> None:
main(sys.argv[1:])
if __name__ == '__main__':
main(sys.argv[1:], 'python -m build') We don't have to use |
@calebrob6 can you see if this is fixed by the last commit? I can also play around with on_step/on_epoch to control frequency and prog_bar to control verbosity. |
Hoping for a quick response to Lightning-AI/pytorch-lightning#18641, otherwise I'll just ignore the mypy warnings. |
Yep, metrics are logged now, however they are logged at strange intervals (see x-axis of graphs below) and I get this warning in the multi-gpu case:
|
By default, lightning should log every 50 training steps (https://lightning.ai/docs/pytorch/stable/extensions/logging.html#logging-frequency), which is why the first train point is near 50, but it doesn't explain why the frequency of train logging increases after that. Update, for whatever reason, when the trainer gets the early stopping signal (I see this in the logs "Trainer was signaled to stop but the required |
Type errors will be fixed by Lightning-AI/pytorch-lightning#18646 |
@calebrob6 I'm inclined to call this PR "good enough" for a first draft and get it merged so we can focus on other last minute features. We can add docs and add new features or fix bugs in future releases. What do you think? |
First thought: Are you feeling okay!? |
(otherwise I don't care if the metric logging is working perfectly) |
They are automatically saved in
The root directory can be controlled using trainer:
default_root_dir: my_logs |
Problem solved! |
This PR replaces our
train.py
script with atorchgeo
script installed by pip. It can be invoked in two ways:It is based on LightningCLI, and supports command-line configuration or YAML/JSON config files. Valid options can be found from the help messages:
An example of the script in action:
It can also be imported and used in a Python script if you need to extend it to add new features:
Documentation
Setuptools packaging:
LightningCLI:
Closes #228
Closes #1352