# PyTorch Lightning

Converting our ImageNet training with NVIDIA DALI to PyTorch Lightning helps organize the code
and lets us reuse modular components provided by the community.

In [1]:
%reload_ext autoreload
%autoreload 2

import os
from glob import glob
from simple_cnn import SimpleCNN

import torch
from torch import nn, optim
import torch.nn.functional as F

from torchinfo import summary
import pytorch_lightning as pl



## Instantiate the Lightning model

The LightningModule in [train_imagenet_pl.py](./train_imagenet_pl.py) also defines the DALI data loaders.
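`train_imagenet_pl.py` is not reproduced here, but DALI's PyTorch iterator (`DALIGenericIterator`) yields each batch as a list with one dictionary per pipeline, keyed by the names given in `output_map`. The sketch below shows how a `training_step` might unpack such a batch. The helper `unpack_dali_batch` and the key names `"data"` and `"label"` are hypothetical, and plain lists stand in for GPU tensors so that it runs anywhere.

```python
from typing import Any, Dict, List, Sequence, Tuple


def unpack_dali_batch(
    batch: Sequence[Dict[str, Any]],
    data_key: str = "data",    # hypothetical output_map name for the images
    label_key: str = "label",  # hypothetical output_map name for the labels
) -> Tuple[Any, List[int]]:
    """Turn one DALI-style batch into an (images, labels) pair.

    The iterator yields a list with one dict per pipeline; with a single
    pipeline per process only the first entry is needed. File-reader labels
    have shape (batch_size, 1), so they are flattened here. With real tensors
    the equivalent is batch[0][label_key].squeeze(-1).long().
    """
    sample = batch[0]
    images = sample[data_key]
    labels = [int(row[0]) for row in sample[label_key]]
    return images, labels


# Plain lists stand in for the tensors DALI would return on the GPU.
fake_batch = [{"data": [[0.1, 0.2], [0.3, 0.4]], "label": [[3], [7]]}]
images, labels = unpack_dali_batch(fake_batch)
print(labels)  # [3, 7]
```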

In [2]:
from train_imagenet_pl import LtngModel

model = LtngModel(
    data_path='/scratch/snx3000/datasets/imagenet/ILSVRC2012_1k',
    arch='resnet18',
    optimizer='AdamW',
    batch_size=32,
    learning_rate=2e-4,
    epochs=2,
)

summary(model, input_size=(model.hparams.batch_size, 3, model.hparams.image_size, model.hparams.image_size))



Layer (type:depth-idx)                        Output Shape              Param #
LtngModel                                     --                        --
├─ResNet: 1-1                                 [256, 1000]               --
│    └─Conv2d: 2-1                            [256, 64, 112, 112]       9,408
│    └─BatchNorm2d: 2-2                       [256, 64, 112, 112]       128
│    └─ReLU: 2-3                              [256, 64, 112, 112]       --
│    └─MaxPool2d: 2-4                         [256, 64, 56, 56]         --
│    └─Sequential: 2-5                        [256, 64, 56, 56]         --
│    │    └─BasicBlock: 3-1                   [256, 64, 56, 56]         73,984
│    │    └─BasicBlock: 3-2                   [256, 64, 56, 56]         73,984
│    └─Sequential: 2-6                        [256, 128, 28, 28]        --
│    │    └─BasicBlock: 3-3                   [256, 128, 28, 28]        230,144
│    │    └─BasicBlock: 3-4                   [256, 128, 28, 28]        295,42
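The `Param #` column can be checked by hand. This sketch assumes torchvision's ResNet-18 layout: convolutions without bias, `BatchNorm2d` layers with a learnable weight and bias per channel, and `BasicBlock`s made of two 3x3 convolutions plus a 1x1 projection shortcut where the number of channels changes. It reproduces the first rows of the table above.

```python
def conv_params(c_in: int, c_out: int, k: int) -> int:
    # torchvision's ResNet convolutions use bias=False, so only weights count
    return c_in * c_out * k * k


def bn_params(c: int) -> int:
    # BatchNorm2d learns a weight and a bias per channel
    # (running mean/variance are buffers, not parameters)
    return 2 * c


def basic_block_params(c_in: int, c_out: int) -> int:
    total = conv_params(c_in, c_out, 3) + bn_params(c_out)
    total += conv_params(c_out, c_out, 3) + bn_params(c_out)
    if c_in != c_out:
        # 1x1 projection shortcut (in ResNet-18 this coincides with the
        # first block of each widening stage)
        total += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return total


stem_conv = conv_params(3, 64, 7)        # Conv2d: 2-1
stem_bn = bn_params(64)                  # BatchNorm2d: 2-2
block_64 = basic_block_params(64, 64)    # BasicBlock: 3-1 / 3-2
block_64_to_128 = basic_block_params(64, 128)  # BasicBlock: 3-3
print(f"{stem_conv:,} {stem_bn:,} {block_64:,} {block_64_to_128:,}")
# 9,408 128 73,984 230,144
```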

## Create a Trainer and fit the model

In [3]:
trainer = pl.Trainer(
    gpus=1,
    default_root_dir=os.path.join(os.environ['SCRATCH'], 'lightning_run'),
    limit_train_batches=50,
    limit_val_batches=50,
    max_epochs=model.hparams.epochs,
    replace_sampler_ddp=False,  # DALI shards the data itself, so don't let Lightning add a DistributedSampler
)

trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.

  | Name | Type   | Params
--------------------------------
0 | net  | ResNet | 11.7 M
--------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.758    Total estimated model params size (MB)
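The `Total estimated model params size (MB)` line follows from the parameter count: at 32-bit precision each parameter takes 4 bytes. Assuming the `ResNet` inside `LtngModel` is torchvision's standard `resnet18` with 1,000 classes (11,689,512 parameters), the arithmetic reproduces the reported value:

```python
# Total parameter count of torchvision's resnet18 with 1000 output classes.
resnet18_params = 11_689_512
bytes_per_param = 4  # float32 weights (precision=32)

size_mb = resnet18_params * bytes_per_param / 1e6
print(f"{resnet18_params / 1e6:.1f} M params -> {size_mb:.3f} MB")
# 11.7 M params -> 46.758 MB
```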


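Setting `replace_sampler_ddp=False` matters because DALI, not a `DistributedSampler`, decides which samples each process reads: every rank's reader is given its own `shard_id` and the total `num_shards`. The sketch below shows a contiguous, near-equal split of that kind in plain Python. It assumes one reader per process with `shard_id` set to the global rank, which is the usual setup but is not shown in this notebook; `shard_bounds` and `shards_for_world` are hypothetical helpers.

```python
from typing import List


def shard_bounds(shard_id: int, num_shards: int, dataset_size: int) -> range:
    # Contiguous split: shard i covers
    # [size * i // num_shards, size * (i + 1) // num_shards)
    start = dataset_size * shard_id // num_shards
    end = dataset_size * (shard_id + 1) // num_shards
    return range(start, end)


def shards_for_world(world_size: int, dataset_size: int) -> List[range]:
    # One reader per process: shard_id = global rank, num_shards = world size
    return [shard_bounds(rank, world_size, dataset_size) for rank in range(world_size)]


# The ImageNet-1k training split has 1,281,167 images.
shards = shards_for_world(world_size=4, dataset_size=1_281_167)
for rank, shard in enumerate(shards):
    print(rank, shard.start, shard.stop, len(shard))
```

Because the shards are disjoint and together cover the dataset, a `DistributedSampler` on top would only split the data a second time.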

## DistributedDataParallel in PyTorch Lightning

```python
trainer = Trainer(gpus=1, num_nodes=2, strategy="ddp", ...)
```

Lightning supports SLURM out of the box, so there is no need to set up any environment variables manually.

Note: this will not work inside the Jupyter notebook; submit the job with `sbatch` using [train_imagenet_pl.sh](./train_imagenet_pl.sh) instead.
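Under SLURM, each process can work out its place in the job from environment variables exported by `srun`. The sketch below reads the same variables that Lightning's SLURM cluster environment relies on (`SLURM_PROCID`, `SLURM_LOCALID`, `SLURM_NODEID`, `SLURM_NTASKS`). `RankInfo` and `rank_info_from_slurm` are hypothetical helpers, and the sketch assumes the batch script starts one task per GPU (`--ntasks-per-node` equal to `gpus`).

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass
class RankInfo:
    global_rank: int
    local_rank: int
    node_rank: int
    world_size: int


def rank_info_from_slurm(env: Mapping[str, str] = os.environ) -> RankInfo:
    # One SLURM task per GPU: each task is one DDP process
    return RankInfo(
        global_rank=int(env["SLURM_PROCID"]),  # 0 .. ntasks-1 across all nodes
        local_rank=int(env["SLURM_LOCALID"]),  # task index within this node
        node_rank=int(env["SLURM_NODEID"]),    # node index within the allocation
        world_size=int(env["SLURM_NTASKS"]),   # total number of tasks
    )


# Second task of a 2-node job with one GPU per node,
# i.e. the process on node 1 for Trainer(gpus=1, num_nodes=2).
fake_env = {"SLURM_PROCID": "1", "SLURM_LOCALID": "0",
            "SLURM_NODEID": "1", "SLURM_NTASKS": "2"}
info = rank_info_from_slurm(fake_env)
print(info)
```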


## DeepSpeedPlugin for PyTorch Lightning

DeepSpeed is available in PyTorch Lightning as a training strategy; see the [Lightning DeepSpeed documentation](https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#deepspeed).

```python
trainer = Trainer(gpus=1, strategy="deepspeed_stage_2", precision=16, ...)
```
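To see what `deepspeed_stage_2` buys, the ZeRO paper's memory accounting for mixed-precision Adam (which also applies to the `AdamW` optimizer used here) can be written out: 2 bytes per parameter for fp16 weights, 2 for fp16 gradients and 12 for the fp32 optimizer state. Stage 1 partitions the optimizer state across GPUs, stage 2 also the gradients, and stage 3 also the parameters. The function below is a back-of-the-envelope sketch of model-state memory only; activations, temporary buffers and fragmentation are ignored, and `zero_model_state_gb` is a hypothetical helper.

```python
def zero_model_state_gb(num_params: float, num_gpus: int, stage: int, k: int = 12) -> float:
    """Per-GPU memory for model states (params, grads, optimizer) in GB.

    fp16 params: 2 bytes, fp16 grads: 2 bytes, fp32 optimizer state
    (master weights, momentum, variance): k = 12 bytes per parameter.
    """
    params, grads, optim = 2 * num_params, 2 * num_params, k * num_params
    if stage >= 1:
        optim /= num_gpus   # stage 1: shard optimizer states
    if stage >= 2:
        grads /= num_gpus   # stage 2: also shard gradients
    if stage >= 3:
        params /= num_gpus  # stage 3: also shard parameters
    return (params + grads + optim) / 1e9


resnet18_params = 11_689_512  # torchvision resnet18, about 11.7 M
for stage in (0, 1, 2, 3):
    gb = zero_model_state_gb(resnet18_params, num_gpus=4, stage=stage)
    print(f"stage {stage}: {gb:.4f} GB per GPU")
```

For ResNet-18 these totals stay well under a gigabyte even without sharding, so the savings matter mainly for much larger models.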

Or use the [plugin](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.plugins.training_type.DeepSpeedPlugin.html) directly for additional options:

```python
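from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

ds_plugin = DeepSpeedPlugin(
    stage=3,                          # ZeRO stage 3: shard optimizer states, gradients and parameters
    offload_optimizer=True,           # offload optimizer states from the GPU
    offload_parameters=True,          # offload parameters from the GPU
    remote_device="nvme",             # where the model is placed when it is first created
    offload_params_device="nvme",     # offload parameters to NVMe rather than CPU memory
    offload_optimizer_device="nvme",  # offload optimizer states to NVMe rather than CPU memory
    nvme_path="/local_nvme",          # filesystem path of the NVMe device used for offloading
)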

trainer = Trainer(gpus=4, strategy=ds_plugin, precision=16, ...)
``` 
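`DeepSpeedPlugin` also accepts a DeepSpeed configuration through its `config` argument (a dict or a path to a JSON file). The dict below is a sketch of a roughly equivalent ZeRO-3 NVMe-offload config written against DeepSpeed's JSON schema. It is an assumption that it matches what the plugin generates from the keyword arguments above; Lightning may fill in further fields such as batch sizes, and `remote_device` is left out of this sketch.

```python
import json

# Sketch of a DeepSpeed config mirroring the DeepSpeedPlugin arguments above.
# Check the keys against the DeepSpeed version you run.
ds_config = {
    "fp16": {"enabled": True},  # counterpart of precision=16
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "nvme", "nvme_path": "/local_nvme"},
        "offload_param": {"device": "nvme", "nvme_path": "/local_nvme"},
    },
}

config_text = json.dumps(ds_config, indent=2)
print(config_text)

# Assumed usage (not run here): DeepSpeedPlugin(config=ds_config), or pass the
# path of a JSON file containing config_text.
```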