This is a ready-to-go PyTorch template utilizing Lightning and wandb.
The template uses Lightning CLI for config management.
It follows most of the Lightning CLI docs, but is integrated with wandb.
Since Lightning CLI instantiates classes on the fly, some workarounds were needed while integrating the WandbLogger into the template.
This might not be the best practice, but it works and is quite convenient.
It uses Lightning CLI, so most of its usage can be found in the official docs.
There are some added arguments related to wandb:

- `--name` or `-n`: Name of the run, displayed in wandb
- `--version` or `-v`: Version of the run, displayed in wandb as tags
Basic cmdline usage is as follows, assuming the cwd is the project root dir.

```bash
python src/main.py fit -c configs/config.yaml -n debug-fit-run -v debug-version
```

If using wandb for logging, change the `project` key in `cli_module/rich_wandb.py`.
If you want to access the log directory in your LightningModule, you can do so as follows.

```python
log_root_dir = self.logger.log_dir or self.logger.save_dir
```
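For example, inside a LightningModule hook (a minimal sketch; the module name and hook are illustrative, and the import path assumes a recent Lightning release):

```python
import lightning.pytorch as pl


class MyModel(pl.LightningModule):  # hypothetical module, for illustration only
    def on_fit_start(self):
        # The template's pattern: prefer log_dir (local loggers), and fall
        # back to save_dir when log_dir is unavailable (e.g. with wandb).
        log_root_dir = self.logger.log_dir or self.logger.save_dir
        print(f"Logs for this run are written under: {log_root_dir}")
```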
If using wandb for logging, model ckpt files are uploaded to wandb.
Since ckpt files are large, a clean-up process is needed.
The clean-up process deletes all model ckpt artifacts that have no aliases (e.g. best, latest).
To turn the clean-up process off, add the following to config.yaml. Every version of the model ckpt files will then be kept in wandb.

```yaml
trainer:
  logger:
    init_args:
      clean: false
```
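For reference, the clean-up is conceptually equivalent to this sketch against the wandb public API (the run path below is a placeholder; the template's actual implementation lives in the custom logger):

```python
import wandb

api = wandb.Api()
run = api.run("entity/project/run_id")  # placeholder run path

# Delete every model checkpoint artifact version that carries no alias
# (i.e. is neither "best" nor "latest").
for artifact in run.logged_artifacts():
    if artifact.type == "model" and not artifact.aliases:
        artifact.delete()
```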
One can save model checkpoints using Lightning callbacks.
A checkpoint contains the model weights and other state_dicts needed to resume training.
There are several ways to save ckpt files, either locally or in the cloud.

- Leave everything at its defaults and ckpt files will be saved locally (at `logs/${name}/${version}/fit/checkpoints`).
- If you want to save ckpt files as wandb Artifacts, add the following config. (The ckpt files will be saved locally too; a retrieval sketch follows this list.)

  ```yaml
  trainer:
    logger:
      init_args:
        log_model: all
  ```

- If you want to save ckpt files in the cloud rather than locally, change the save path by adding the following config. (The ckpt files will NOT be saved locally.)

  ```yaml
  model_ckpt:
    dirpath: gs://bucket_name/path/for/checkpoints
  ```
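If checkpoints were logged as wandb Artifacts (second option above), they can be retrieved later through the wandb API. A minimal sketch, assuming the default `model-<run_id>` artifact naming of Lightning's WandbLogger; the entity, project, and alias are placeholders:

```python
import wandb

api = wandb.Api()
# Placeholder artifact path; adjust entity/project/run_id/alias to your run.
artifact = api.artifact("entity/project/model-run_id:best", type="model")
ckpt_dir = artifact.download()  # local directory containing the .ckpt file
```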
You can set async checkpoint saving by providing the config as follows.

```yaml
trainer:
  plugins:
    - AsyncCheckpointIO
```
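The yaml above is equivalent to passing the plugin programmatically. A sketch (the import path assumes a recent Lightning version):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.plugins.io import AsyncCheckpointIO

# Checkpoints are written by a background thread, so training steps
# are not blocked on checkpoint I/O.
trainer = Trainer(plugins=[AsyncCheckpointIO()])
```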
Just add the BatchSizeFinder callback in the config.

```yaml
trainer:
  callbacks:
    - class_path: BatchSizeFinder
```

Or add it on the cmdline.

```bash
python src/main.py fit -c configs/config.yaml --trainer.callbacks+=BatchSizeFinder
```

For tuning, run `src/tune.py` (NOTE: no subcommand in the cmdline).

```bash
python src/tune.py -c configs/config.yaml
```
Basically, all logs are stored in `logs/${name}/${version}/${job_type}`, where `${name}` and `${version}` are configured in the yaml file or on the cmdline.
`${job_type}` can be one of fit, test, validate, etc.
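For example, a fit run followed by a test run with name `my-run` and version `v1` would produce a layout like this (illustrative):

```
logs/
└── my-run/
    └── v1/
        ├── fit/
        │   └── checkpoints/
        └── test/
```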
To run a test:

```bash
python src/main.py test -c configs/config.yaml -n debug-test-run -v debug-version --ckpt_path YOUR_CKPT_PATH
```

To-do:

- Check pretrained weight loading
- Consider multiple-optimizer use cases (e.g. GANs)
- Add instructions in README (ongoing)
- Clean code