# Using PyTorch Lightning


[PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/starter/introduction.html) is a thin scaffolding that abstracts away some of the boilerplate of ML training in PyTorch (it also helps with distributed training, reduced-precision models, and more). It does so in a way that lets us focus on the key ingredients while retaining the flexibility and power of the Torch library.

### 🛠️ Installation and set-up

In [1]:
# Uncomment this and run inside colab
#!pip install -q pytorch-lightning

## Setting up the dataloader

We'll use the MNIST dataset with the default PyTorch dataloader, and some image pre-processing hooked in from [torch transforms v2](https://pytorch.org/vision/stable/transforms.html#v1-or-v2-which-one-should-i-use).

In [2]:
import time
import torch
# We load the MNIST dataset (numbers) from torchvision
from torchvision.datasets import MNIST
# Some helpful tools are in ``transforms``
from torchvision.transforms import v2
# Get the DataLoader
from torch.utils.data import DataLoader, random_split

transform = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize((0.1307,), (0.3081,))])

dataset = MNIST(root="./MNIST", download=True, transform=transform)
training_set, validation_set = random_split(dataset, [55000, 5000])
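
To make the transform pipeline above concrete, here is a minimal sketch of the arithmetic behind the scaling and normalization steps, written by hand in plain PyTorch rather than with ``torchvision``. The tiny ``fake_image`` tensor is made up for illustration; the constants 0.1307 and 0.3081 are the ones passed to ``v2.Normalize`` above.

```python
import torch

# Hypothetical 1x2x2 "image" with raw uint8 pixel values (not real MNIST data)
fake_image = torch.tensor([[[0, 255], [128, 64]]], dtype=torch.uint8)

# Step 1: convert to float32 and rescale from [0, 255] to [0, 1],
# mirroring ToDtype(torch.float32, scale=True) for uint8 input
scaled = fake_image.to(torch.float32) / 255.0

# Step 2: subtract the mean and divide by the standard deviation,
# mirroring Normalize((0.1307,), (0.3081,)) on the single channel
mean, std = 0.1307, 0.3081
normalized = (scaled - mean) / std

print(normalized)
```

A black pixel (0) ends up slightly negative and a white pixel (255) ends up well above 1, so the inputs to the network are roughly centered around zero.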

In [3]:
training_loader = DataLoader(training_set, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_set, batch_size=64)
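
As a quick sanity check on what these loaders will yield, the sketch below works out the number of batches per epoch from the split sizes and batch size used above. It assumes the ``DataLoader`` default of ``drop_last=False``, so the final, smaller batch is kept.

```python
import math

n_train, n_val, batch_size = 55000, 5000, 64  # sizes used above

# With drop_last=False the last partial batch is kept, so the number of
# batches per epoch is the ceiling of samples / batch_size
train_batches = math.ceil(n_train / batch_size)
val_batches = math.ceil(n_val / batch_size)

# Size of the final (partial) batch in each loader
last_train_batch = n_train - (train_batches - 1) * batch_size
last_val_batch = n_val - (val_batches - 1) * batch_size

print(train_batches, last_train_batch)  # 860 24
print(val_batches, last_val_batch)      # 79 8
```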

## Setting up the Model in the Scaffold

* Torch: As before, we load the ``CrossEntropyLoss`` for this model, as we are performing a classification task
* Torch: We use Adam as the optimizer for training the model
* Model: We wish to classify digits (0-9, hence 10 classes), doing so with a Feed Forward network/Multi-Layer Perceptron (MLP). We'll ingest a 28 * 28 image, and pass it through 2 hidden layers with ``n_layer_1`` and ``n_layer_2`` nodes
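
The architecture described in the last bullet can be sketched in plain PyTorch to check the shapes and the parameter count. This is only an illustration: the layer sizes 128 and 256 are example values, and the dummy batch is all zeros.

```python
import torch
from torch.nn import Linear, ReLU, Sequential

n_layer_1, n_layer_2, n_classes = 128, 256, 10  # example layer sizes

mlp = Sequential(
    Linear(28 * 28, n_layer_1), ReLU(),
    Linear(n_layer_1, n_layer_2), ReLU(),
    Linear(n_layer_2, n_classes),
)

# Each Linear layer holds in_features * out_features weights plus out_features biases
n_params = sum(p.numel() for p in mlp.parameters())

# A batch of 4 dummy images, flattened from (4, 1, 28, 28) to (4, 784)
dummy = torch.zeros(4, 1, 28, 28)
logits = mlp(dummy.flatten(start_dim=1))

print(n_params)      # 136074
print(logits.shape)  # torch.Size([4, 10])
```

Each row of ``logits`` holds one unnormalized score per digit class.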

In [4]:
from torch.nn import Linear, CrossEntropyLoss, Sequential, ReLU, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from pytorch_lightning import LightningModule

class MNIST_LitModule(LightningModule):

    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3, compile=False):
        '''method used to define our model parameters'''
        super().__init__()

        # mnist images are (1, 28, 28) (channels, height, width)
        # self.layer_1 = Linear(28 * 28, n_layer_1)
        # self.layer_2 = Linear(n_layer_1, n_layer_2)
        # self.layer_3 = Linear(n_layer_2, n_classes)

        # use Sequential to define the model
        self.seq = Sequential(
            Linear(28 * 28, n_layer_1),
            ReLU(),
            Linear(n_layer_1, n_layer_2),
            ReLU(),
            Linear(n_layer_2, n_classes)
        )
        if compile:
            self.seq = torch.compile(self.seq)
            

        # loss
        self.loss = CrossEntropyLoss()

        # optimizer parameters
        self.lr = lr

    def forward(self, x):
        '''method used for inference input -> output'''

        batch_size, channels, height, width = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        x = self.seq(x)
        
        # As an alternative
        # one could do the calls in sequence here:
        # x = self.layer_1(x)
        # x = F.relu(x)
        # x = self.layer_2(x)
        # x = F.relu(x)
        # x = self.layer_3(x)

        return x

    def training_step(self, batch, batch_idx):
        '''needs to return a loss from a single batch'''
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)

        return loss

    def validation_step(self, batch, batch_idx):
        '''used for logging metrics'''
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)

        # Let's return preds to use it in a custom callback
        return preds

    def test_step(self, batch, batch_idx):
        '''used for logging metrics'''
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)

    def configure_optimizers(self):
        '''defines model optimizer'''
        return Adam(self.parameters(), lr=self.lr)

    def _get_preds_loss_accuracy(self, batch):
        '''convenience function since train/valid/test steps are similar'''
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y, 'multiclass', num_classes=10)
        return preds, loss, acc
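
To show what the helper ``_get_preds_loss_accuracy`` computes, here is a small hand-worked sketch in plain PyTorch. The logits and labels are made up, and the accuracy is computed as a simple fraction of correct predictions, which is assumed to match what ``torchmetrics`` returns for the multiclass task with its default (micro) averaging.

```python
import torch
from torch.nn import CrossEntropyLoss

# Made-up logits for a batch of 4 samples and 3 classes (illustration only)
logits = torch.tensor([[2.0, 0.5, 0.1],
                       [0.2, 1.5, 0.3],
                       [0.1, 0.2, 3.0],
                       [1.0, 2.0, 0.5]])
y = torch.tensor([0, 1, 2, 0])  # made-up labels; the last sample is misclassified

# Predicted class = index of the largest logit
preds = torch.argmax(logits, dim=1)

# Accuracy = fraction of predictions equal to the label
acc = (preds == y).float().mean()

# CrossEntropyLoss = mean negative log-softmax probability of the true class
loss = CrossEntropyLoss()(logits, y)
manual = -torch.log_softmax(logits, dim=1)[torch.arange(len(y)), y].mean()

print(preds)                        # tensor([0, 1, 2, 1])
print(acc)                          # tensor(0.7500)
print(torch.isclose(loss, manual))  # tensor(True)
```

Note that the loss is computed from the raw logits, while the accuracy only depends on the argmax.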

Let's instantiate a model from this class.

In [5]:
model = MNIST_LitModule(n_layer_1=128, n_layer_2=128)

## Model Checkpoints

We'd like to keep track of not only configuration parameters and accuracy or loss, but also checkpoints of the model itself.

In [6]:
from pytorch_lightning.callbacks import ModelCheckpoint

# the monitored quantity should be one which is 'logged' above
checkpoint_callback = ModelCheckpoint(monitor='val_accuracy', mode='max')
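
The idea behind ``monitor`` and ``mode`` can be sketched in a few lines of plain Python. This is not Lightning's implementation, only the selection rule it is assumed to follow, and the per-epoch accuracies are hypothetical values, not results from this notebook.

```python
# Hypothetical per-epoch validation accuracies (not results from this notebook)
val_accuracy_per_epoch = [0.91, 0.95, 0.94, 0.96, 0.955]

mode = "max"  # "max" for accuracy-like metrics, "min" for loss-like metrics
if mode == "max":
    is_better = lambda new, best: new > best
else:
    is_better = lambda new, best: new < best

best_score, best_epoch = None, None
for epoch, score in enumerate(val_accuracy_per_epoch):
    # Keep the checkpoint only when the monitored value improves
    if best_score is None or is_better(score, best_score):
        best_score, best_epoch = score, epoch

print(best_epoch, best_score)  # 3 0.96
```

After training, the callback's ``best_model_path`` and ``best_model_score`` attributes report which checkpoint was kept.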

## Logging

We'd also like to keep track of the metrics logged during training, such as loss and accuracy.

As a stand-in for the wandb logger, we'll use a default TensorBoardLogger, which requires no server connection.

In [7]:
from pytorch_lightning.loggers import TensorBoardLogger

tb_logger = TensorBoardLogger("tb_logs", name="mnist_lightning")

## Progress Bars

The default TQDM progress bar doesn't render well in this environment, so let's use a prettier one.

In [8]:
from pytorch_lightning.callbacks import RichProgressBar

## Train The Model

We'll set up a trainer with the ``tb_logger`` instantiated above as the logger, pass in the ``checkpoint_callback`` and the progress bar callback, and choose the accelerator and ``max_epochs`` to train over.

In [9]:
import time

In [10]:
from pytorch_lightning import Trainer               # This class will allow us to abstract away those train/val loops
trainer = Trainer(logger=tb_logger,                 # hook to pass any of multiple loggers
                  callbacks=[RichProgressBar(),     # pretty progress bars
                             checkpoint_callback],  # our model checkpoint callback
                  accelerator="cpu",                # use "cpu", "gpu", or "auto"
                  max_epochs=5,                     # number of epochs
                 )

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/nmangane/mambaforge/envs/cofi-2023-ext/lib/python3.10/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


In [11]:
start = time.perf_counter()
trainer.fit(model, training_loader, validation_loader)
print(time.perf_counter() - start, "to finish fit call")

`Trainer.fit` stopped: `max_epochs=5` reached.


46.67555041704327 to finish fit call
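
The timing pattern used in the cell above can be wrapped in a small reusable helper. The sketch below is one possible way to do it; the names ``timed`` and ``timings`` are hypothetical, and ``time.sleep`` stands in for the ``trainer.fit`` call.

```python
import time
from contextlib import contextmanager

timings = {}  # label -> elapsed seconds

@contextmanager
def timed(label):
    """Record and print the wall-clock time spent inside the with-block."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[label] = time.perf_counter() - start
        print(f"{timings[label]:.2f} s to finish {label}")

# Stand-in workload; in the notebook this would wrap trainer.fit(...)
with timed("sleep call"):
    time.sleep(0.01)
```

Keeping the results in a dictionary makes it easy to compare runs on different backends afterwards.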


## Switching backend

A small model like this one won't see much improvement from an accelerator (on the Apple silicon MPS backend, it is currently slower).

In [12]:
# Let's instantiate another model with the same characteristics
# We need a new instantiation of the trainer, including a new logger instance, but we'll keep the same callbacks
model2 = MNIST_LitModule(n_layer_1=128, n_layer_2=128)
trainer2 = Trainer(logger=TensorBoardLogger("tb_logs", name="mnist_lightning_accelerated"),
                   callbacks=[RichProgressBar(),     # pretty progress bars
                              checkpoint_callback],  # our model checkpoint callback
                   accelerator="gpu",                # use "cpu", "gpu", or "auto"
                   max_epochs=5,                     # number of epochs
                  )

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
start = time.perf_counter()
trainer2.fit(model2, training_loader, validation_loader)
print(time.perf_counter() - start, "to finish fit call")

`Trainer.fit` stopped: `max_epochs=5` reached.


81.56411016604397 to finish fit call
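
When writing plain PyTorch code outside of Lightning, the backend has to be chosen by hand. The sketch below shows one way to probe what the current machine offers; the function name ``pick_accelerator`` is hypothetical.

```python
import torch

def pick_accelerator():
    """Return "cuda", "mps", or "cpu" depending on what this machine offers."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps exists in PyTorch 1.12 and later; guard for older versions
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"

device = pick_accelerator()
print(device)

# Tensors (and models) are moved to the chosen device with .to(...)
x = torch.zeros(2, 2).to(device)
```

With Lightning, passing ``accelerator="auto"`` to the ``Trainer`` leaves a similar choice to the library. As the timings above show, the fastest backend for a small model is not necessarily the GPU, so it is worth measuring.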


## Saving and Loading Models

We can save the model and its weights together, or just the weights (presuming we can re-instantiate the model from our code).

In [14]:
# Save just the weights
torch.save(model2.state_dict(), "COFI-2.1-model_weights")

In [15]:
# Load the weights back into a freshly instantiated model
model2rt = MNIST_LitModule(n_layer_1=128, n_layer_2=128)
model2rt.load_state_dict(torch.load("COFI-2.1-model_weights"))

<All keys matched successfully>

In [25]:
model2.eval()
model2rt.eval()
inp = torch.rand(100,1,28,28)
a = model2(inp)
b = model2rt(inp)
torch.all(torch.isclose(a, b))

tensor(True)
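
To see what a saved state dict actually contains, the sketch below inspects one for a tiny stand-in network (not the MNIST model) and round-trips it through an in-memory buffer instead of a file on disk.

```python
import io
import torch
from torch.nn import Linear, ReLU, Sequential

# Tiny stand-in network (not the MNIST model above)
net = Sequential(Linear(4, 3), ReLU(), Linear(3, 2))

# The state dict maps parameter names to tensors; ReLU has no parameters
for name, tensor in net.state_dict().items():
    print(name, tuple(tensor.shape))
# 0.weight (3, 4)
# 0.bias (3,)
# 2.weight (2, 3)
# 2.bias (2,)

# Round-trip through an in-memory buffer
buffer = io.BytesIO()
torch.save(net.state_dict(), buffer)
buffer.seek(0)

# Loading requires a model with the same architecture to receive the weights
restored = Sequential(Linear(4, 3), ReLU(), Linear(3, 2))
restored.load_state_dict(torch.load(buffer))
```

Because only named tensors are stored, the code that defines the architecture must be available when loading.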

In [20]:
# Save the model and weights together
torch.save(model2, "model2.pth")

In [22]:
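# weights_only=False is needed on recent PyTorch versions to unpickle a full model;
# only use it with files you trust
model2rt2 = torch.load("model2.pth", weights_only=False)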

In [24]:
model2rt2.eval()
c = model2rt2(inp)
torch.all(torch.isclose(a, c))

tensor(True)

## PyTorch Lightning Checkpoints
See the documentation linked below for Lightning's checkpointing features, which allow you to resume interrupted training (checkpoints store not only the ``model.state_dict()``, but also the state of the trainer, optimizer, hyperparameters, etc.).

[https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html](https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html)
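
To illustrate why a resumable checkpoint needs more than the model weights, here is a minimal sketch in plain PyTorch that bundles model state, optimizer state, and an epoch counter. It is not Lightning's checkpoint format: the dictionary keys are our own choice, the model is a tiny stand-in, and an in-memory buffer replaces the checkpoint file.

```python
import io
import torch
from torch.nn import Linear
from torch.optim import Adam

model = Linear(4, 2)  # tiny stand-in model
optimizer = Adam(model.parameters(), lr=1e-3)

# One dummy training step so the optimizer has internal state to save
loss = model(torch.rand(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# Bundle everything needed to resume; the key names here are our own choice
checkpoint = {
    "epoch": 1,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}
buffer = io.BytesIO()  # in-memory stand-in for a checkpoint file
torch.save(checkpoint, buffer)
buffer.seek(0)

# Restore into freshly constructed objects
resumed_model = Linear(4, 2)
resumed_optimizer = Adam(resumed_model.parameters(), lr=1e-3)
loaded = torch.load(buffer)
resumed_model.load_state_dict(loaded["model_state"])
resumed_optimizer.load_state_dict(loaded["optimizer_state"])
start_epoch = loaded["epoch"]
```

With Lightning this bookkeeping is handled for you: passing ``ckpt_path`` to ``trainer.fit`` resumes from a saved checkpoint, as described in the documentation linked above.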