jungin500/lightning-template

Easy-to-use PyTorch Lightning CNN model training template with MobileNetV2 and ImageNet dataloader example

Usage (conda recommended)

# Change "torch-trainer" to your new environment name
export CONDA_ENV=torch-trainer

# Create new environment along with PyTorch (with GPU support compatible up to RTX 3xxx series)
conda create -n $CONDA_ENV -c pytorch python=3.9 pytorch torchvision cudatoolkit=11.3
conda activate $CONDA_ENV
pip3 install -r requirements.txt

# ... Edit the model and configuration

# Train your model
python3 train.py wandb.login_key=123abc model.type=large ...
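
The `key.subkey=value` arguments in the command above are dotted overrides applied on top of the values in config.yaml. As a rough illustration of that idea only (this is a standard-library sketch, not Hydra's implementation; `apply_overrides` and the default values in `base` are hypothetical), the merge can be pictured as:

```python
from copy import deepcopy


def apply_overrides(config, overrides):
    """Return a copy of `config` with 'a.b=value' style overrides applied."""
    result = deepcopy(config)
    for item in overrides:
        dotted_key, _, value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            # Walk (or create) the nested section for each dotted component.
            node = node.setdefault(name, {})
        # Values stay plain strings in this sketch; Hydra itself also parses types.
        node[leaf] = value
    return result


# Hypothetical defaults standing in for config.yaml.
base = {"wandb": {"login_key": None}, "model": {"type": "small"}}
merged = apply_overrides(base, ["wandb.login_key=123abc", "model.type=large"])
print(merged)
```

The base dictionary is left untouched, so the defaults from the file and the values given on the command line remain distinguishable.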

Details

  • PyTorch Lightning for easy-to-use NVIDIA AMP (Automatic Mixed Precision) and TensorFlow-like callbacks (pytorch_lightning.callbacks.ModelCheckpoint, pytorch_lightning.callbacks.LearningRateMonitor, and so on)
  • Wandb for logging hyperparameters, loss, accuracy, model parameters, and GPU usage
  • Hydra for configuration management
  • Loguru for simple and powerful logging
  • NVIDIA DALI for data loading, including nvJPEG (GPU-based JPEG decoding) and on-GPU data augmentation

Example pipeline

  • Model: torchvision.models.mobilenet Large/Small
  • Dataset: ImageNet with torchvision dataset OR via NVIDIA DALI (See model.py and config.yaml)
  • Loss/Criterion: torch.nn.CrossEntropyLoss
  • Optimizer: torch.optim.RMSprop (See config.yaml)
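
To make the loss and optimizer choices above concrete, here is a minimal plain-PyTorch sketch of one training step with torch.nn.CrossEntropyLoss and torch.optim.RMSprop. It is not the code in model.py: the tiny linear classifier stands in for the MobileNet backbone so the example runs quickly on CPU, and the batch shape, learning rate, and `training_step` helper are assumptions made for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in for the MobileNet backbone (assumption): a tiny 10-class classifier.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)


def training_step(images, labels):
    """Run one optimization step and return the scalar loss."""
    optimizer.zero_grad()
    loss = criterion(model(images), labels)  # logits vs. integer class labels
    loss.backward()
    optimizer.step()
    return loss.item()


# Random data in place of ImageNet batches, reused to show the loss going down.
images = torch.randn(4, 3, 8, 8)
labels = torch.randint(0, 10, (4,))
losses = [training_step(images, labels) for _ in range(20)]
print(f"first loss: {losses[0]:.3f}, last loss: {losses[-1]:.3f}")
```

Inside a LightningModule the same three pieces are typically split between `training_step` (loss computation) and `configure_optimizers` (optimizer construction), with Lightning handling the backward pass and the optimizer step.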

TODO

  • Customize loss function by yaml config parameters
  • Customize multiple schedulers using yaml config

Done so far

  • Integrated the NVIDIA DALI pipeline (Linux only; enable via dataloader.use_dali=True)
  • Moved training artifacts into a single directory
  • Removed the sample dataloader and integrated it into model.py (inside the LightningModule)
  • Added support for multi-GPU training (as well as Kubernetes)
  • Easy Wandb login by specifying the API key in config.yaml (be careful not to expose the API key!)
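
One way to lower the risk of exposing the API key is to prefer an environment variable over the value stored in config.yaml and to mask the key whenever it is printed. The sketch below assumes that approach; `resolve_wandb_key` and `mask_secret` are hypothetical helpers that are not part of this repository, while `WANDB_API_KEY` is the environment variable that Wandb itself reads.

```python
import os


def resolve_wandb_key(config_value=None, env=None):
    """Prefer WANDB_API_KEY from the environment, else fall back to the config value."""
    env = os.environ if env is None else env
    return env.get("WANDB_API_KEY") or config_value


def mask_secret(secret, visible=4):
    """Return the secret with everything after the first `visible` characters hidden."""
    if not secret:
        return "<unset>"
    if len(secret) <= visible:
        return "*" * len(secret)
    return secret[:visible] + "*" * (len(secret) - visible)


# The env mapping is passed explicitly here so the example does not depend on the shell.
key = resolve_wandb_key("123abc", env={"WANDB_API_KEY": "from-env-key"})
masked = mask_secret(key)
print(masked)
```

With this pattern the key in config.yaml can stay empty in version control, and logs only ever show the masked form.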