## Introduction to PyTorch Lightning

PyTorch Lightning is a popular wrapper around PyTorch that simplifies many common tasks:

* Removes the need for much of the boilerplate code in the training loop, such as tracking progress, computing metrics, backpropagating, and resetting gradients
* Makes it easy to access and use multiple GPUs
* Makes it easy to save a model, check its performance, and train it for further epochs if needed
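The first bullet is easier to see with a concrete example. Below is a minimal sketch of the per-batch boilerplate that a vanilla PyTorch loop needs and that Lightning's Trainer takes over: the forward pass, the loss, `zero_grad()`, `backward()`, `step()`, and manual metric bookkeeping. It uses random stand-in tensors shaped like MNIST instead of the real dataset. `train_one_epoch` is a hypothetical helper name, not part of any library.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# random stand-in data shaped like MNIST: 256 images of 1x28x28, labels in 0..9
images = torch.rand(256, 1, 28, 28)
labels = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=True)

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 32), nn.ReLU(), nn.Linear(32, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def train_one_epoch(model, loader, optimizer):
    """Hypothetical helper: one epoch of the boilerplate Lightning's Trainer runs for us."""
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for x, y in loader:
        logits = model(x)                              # forward pass
        loss = nn.functional.cross_entropy(logits, y)  # loss
        optimizer.zero_grad()                          # reset gradients from the previous batch
        loss.backward()                                # backpropagate
        optimizer.step()                               # update the weights
        # manual metric tracking (torchmetrics + self.log replace this in Lightning)
        total_loss += loss.item() * y.size(0)
        correct += (logits.argmax(dim=1) == y).sum().item()
        seen += y.size(0)
    return total_loss / seen, correct / seen


for epoch in range(3):
    avg_loss, acc = train_one_epoch(model, loader, optimizer)
    print(f"epoch {epoch}: loss={avg_loss:.4f} acc={acc:.3f}")
```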

We will introduce PyTorch Lightning model classes, which differ a bit from vanilla PyTorch modules, and show how to use them for MNIST classification (compare to car_mnist_projects.ipynb).

We will also introduce TensorBoard, a visualization tool for assessing model performance.

In [2]:
import pytorch_lightning as pl
import torch
import torch.nn as nn

from torchmetrics import __version__ as torchmetrics_version
# packaging.version replaces the deprecated pkg_resources.parse_version
from packaging.version import parse as parse_version

from torchmetrics import Accuracy

from IPython.display import Image

### PyTorch Lightning model definition

In [3]:
# here is an MLP implemented as a PyTorch Lightning (PL) module
# PL recognizes several method names (training_step, configure_optimizers, ...); see comments below
class MultiLayerPerceptron(pl.LightningModule):
    def __init__(self, image_shape=(1, 28, 28), hidden_units=(32, 16)):
        super().__init__()  # we subclass pl.LightningModule (itself an nn.Module subclass) instead of nn.Module

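        # torchmetrics' Accuracy objects accumulate predictions across batches (update) and compute accuracy on demand (compute)
        # newer torchmetrics versions require the task and num_classes arguments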
        if parse_version(torchmetrics_version) > parse_version("0.8"):
            self.train_acc = Accuracy(task="multiclass", num_classes=10)
            self.valid_acc = Accuracy(task="multiclass", num_classes=10)
            self.test_acc = Accuracy(task="multiclass", num_classes=10)
        else:
            self.train_acc = Accuracy()
            self.valid_acc = Accuracy()
            self.test_acc = Accuracy()

        # the neural network model itself will be similar to the one used in the MNIST example in car_mnist_projects.ipynb 
        input_size = image_shape[0] * image_shape[1] * image_shape[2] 
        all_layers = [nn.Flatten()]
        for hidden_unit in hidden_units: 
            layer = nn.Linear(input_size, hidden_unit) 
            all_layers.append(layer) 
            all_layers.append(nn.ReLU()) 
            input_size = hidden_unit 
 
        all_layers.append(nn.Linear(hidden_units[-1], 10)) 
        self.model = nn.Sequential(*all_layers)
    
    def forward(self, x):
        # forward pass is very simple given our defined self.model
        x = self.model(x)
        return x
    
    # PL recognizes a function called training_step that tells it what actions to perform to train on a batch of data
    # note that now training is defined inside the class, rather than as separate code
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)  # self(x) calls forward() via nn.Module.__call__ and returns the output logits
        loss = nn.functional.cross_entropy(logits, y)  # loss function
        preds = torch.argmax(logits, dim=1)   # predicted class = index of the largest logit
        self.train_acc.update(preds, y)  # accumulate this batch's predictions for the epoch-level accuracy
        self.log("train_loss", loss, prog_bar=True)  # log the training loss
        return loss
    
    # similarly define test_step, defining what is done in a test step on a batch
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.test_acc.update(preds, y)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_acc.compute(), prog_bar=True)
        return loss
    
    # similarly define validation_step, which runs on each validation batch;
    # by default Lightning runs the validation loop at the end of every training epoch.
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.valid_acc.update(preds, y)
        self.log("valid_loss", loss, prog_bar=True)
        self.log("valid_acc", self.valid_acc.compute(), prog_bar=True)
        return loss
    
    # Conditionally define epoch end methods based on PyTorch Lightning version
    # The epoch end methods run at the end of each epoch (the step ones run for each batch)
    # Here we just use the Accuracy() metric from self.train_acc, etc, to compute the accuracy
    #    for the whole epoch given the accuracy values accumulated at each step.
    if parse_version(pl.__version__) >= parse_version("2.0"):
        # For PyTorch Lightning 2.0 and above
        # the hook is named on_train_epoch_end; a method called on_training_epoch_end would never run
        def on_train_epoch_end(self):
            self.log("train_acc", self.train_acc.compute())
            self.train_acc.reset()

        def on_validation_epoch_end(self):
            self.log("valid_acc", self.valid_acc.compute())
            self.valid_acc.reset()

        def on_test_epoch_end(self):
            self.log("test_acc", self.test_acc.compute())
            self.test_acc.reset()

    else:
        # For PyTorch Lightning < 2.0
        def training_epoch_end(self, outs):
            self.log("train_acc", self.train_acc.compute())
            self.train_acc.reset()

        def validation_epoch_end(self, outs):
            self.log("valid_acc", self.valid_acc.compute())
            self.valid_acc.reset()

        def test_epoch_end(self, outs):
            self.log("test_acc", self.test_acc.compute())
            self.test_acc.reset()
    
    # PL recognizes the method "configure_optimizers" to define our optimizer
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer
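The epoch-end hooks above depend on the metric objects doing three things: accumulating counts across batches (`update()`), producing a single number per epoch (`compute()`), and starting over (`reset()`). The sketch below shows roughly why this differs from averaging per-batch accuracies. It uses a hypothetical, simplified `RunningAccuracy` class in plain Python, which is not torchmetrics' actual implementation.

```python
class RunningAccuracy:
    """Hypothetical, simplified stand-in for the update/compute/reset pattern."""

    def __init__(self):
        self.reset()

    def update(self, preds, targets):
        # accumulate counts across batches instead of storing per-batch accuracies
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self):
        return self.correct / self.total

    def reset(self):
        self.correct = 0
        self.total = 0


# two batches: 4 samples all correct, then a final batch of 1 sample that is wrong
batches = [([1, 2, 3, 4], [1, 2, 3, 4]), ([0], [5])]

metric = RunningAccuracy()
batch_accs = []
for preds, targets in batches:
    metric.update(preds, targets)
    batch_accs.append(sum(p == t for p, t in zip(preds, targets)) / len(targets))

epoch_acc = metric.compute()                          # 4 correct / 5 seen = 0.8
mean_of_batches = sum(batch_accs) / len(batch_accs)   # (1.0 + 0.0) / 2 = 0.5
print(epoch_acc, mean_of_batches)
metric.reset()  # start fresh for the next epoch, as the epoch-end hooks do
```

When the final batch is smaller than the others, a plain mean of batch accuracies over-weights it. The accumulated version weights every sample equally.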


### PyTorch Lightning DataModule

PL provides the `LightningDataModule` class for encapsulating data downloading, splitting, and `DataLoader` construction. As with the training code, putting these steps in one class keeps things organized instead of leaving free-floating code in the notebook.

In [4]:
from torch.utils.data import DataLoader
from torch.utils.data import random_split
 
from torchvision.datasets import MNIST
from torchvision import transforms

In [None]:
class MnistDataModule(pl.LightningDataModule):
    def __init__(self, data_path='../datasets/'):
        super().__init__()
        self.data_path = data_path
        self.transform = transforms.Compose([transforms.ToTensor()])
        
    def prepare_data(self):
        MNIST(root=self.data_path, download=True) 

    def setup(self, stage=None):
        # stage is one of 'fit', 'validate', 'test', or 'predict'
        # here we ignore it and build all three splits regardless of stage
        mnist_all = MNIST( 
            root=self.data_path,
            train=True,
            transform=self.transform,  
            download=False
        ) 

        self.train, self.val = random_split(
            mnist_all, [55000, 5000], generator=torch.Generator().manual_seed(1)
        )

        self.test = MNIST( 
            root=self.data_path,
            train=False,
            transform=self.transform,  
            download=False
        ) 

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=64, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=64, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=64, num_workers=4)
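The batch counts that appear in the progress bars follow from the split sizes and `batch_size=64`. A `DataLoader` keeps the final partial batch by default (`drop_last=False`), so the number of batches is a ceiling division. There is one optimizer step per training batch, so the global step after zero-indexed epoch `e` is `(e + 1) * 860`. Here is a quick plain-Python check; `global_step_after` is a hypothetical helper name.

```python
import math

BATCH_SIZE = 64
# MNIST: 60,000 training images split 55,000/5,000 in setup(), plus 10,000 test images
split_sizes = {"train": 55_000, "val": 5_000, "test": 10_000}

# DataLoader keeps the last partial batch by default (drop_last=False), hence ceil
batches = {name: math.ceil(n / BATCH_SIZE) for name, n in split_sizes.items()}
print(batches)  # {'train': 860, 'val': 79, 'test': 157}


def global_step_after(epoch_index, steps_per_epoch=batches["train"]):
    # one optimizer step per training batch; epoch indices start at 0
    return (epoch_index + 1) * steps_per_epoch


print(global_step_after(8), global_step_after(2))  # 7740 2580
```

These values match the 860/860 and 157/157 progress bars in the outputs. They also match checkpoint names such as epoch=8-step=7740.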
    


In [6]:
# instantiate the data module; the Trainer calls its prepare_data() and setup() methods for us
torch.manual_seed(1) 
mnist_dm = MnistDataModule()

### PyTorch Lightning Trainer Class

PL's Trainer class runs the training loop, fitting your model class on the data provided by the data module. Once the classes above are defined, training takes only a few lines, with no boilerplate such as zeroing out gradients. It also makes it easy to train on multiple GPUs.

In [8]:
from pytorch_lightning.callbacks import ModelCheckpoint

mnistclassifier = MultiLayerPerceptron()

# callbacks are a PyTorch Lightning feature that lets us run extra code at specific points
#    during training; we can also write our own custom callbacks.
# the idea is to keep the core PL module code clean, with bells and whistles as callbacks.
# here we use PL's built-in ModelCheckpoint callback to keep the checkpoint with the highest
#    validation accuracy seen so far, in case the model starts to overfit and valid_acc
#    drops in later epochs.
callbacks = [ModelCheckpoint(save_top_k=1, mode='max', monitor="valid_acc")]  # keep only the top-1 checkpoint

if torch.cuda.is_available():  # if you have GPUs
    trainer = pl.Trainer(max_epochs=10, callbacks=callbacks, devices=1)
else:
    trainer = pl.Trainer(max_epochs=10, callbacks=callbacks)

trainer.fit(model=mnistclassifier, datamodule=mnist_dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | train_acc | MulticlassAccuracy | 0      | train
1 | valid_acc | MulticlassAccuracy | 0      | train
2 | test_acc  | MulticlassAccuracy | 0      | train
3 | model     | Sequential         | 25.8 K | train
---------------------------------------------------------
25.8 K    Trainable params
0         Non-trainable params
25.8 K    Total params
0.103     Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode



c:\Users\natha\Documents\pyro_pytorch_tutorials\pyro_pytorch_venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.



c:\Users\natha\Documents\pyro_pytorch_tutorials\pyro_pytorch_venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 9: 100%|██████████| 860/860 [00:45<00:00, 18.84it/s, v_num=0, train_loss=0.0912, valid_loss=0.161, valid_acc=0.953]

`Trainer.fit` stopped: `max_epochs=10` reached.
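The model summary printed above reports 25.8 K trainable parameters and a size estimate of 0.103 MB. Each `nn.Linear(n_in, n_out)` holds `n_in * n_out` weights plus `n_out` biases, and `Flatten` and `ReLU` have no parameters. That means the count for the default `hidden_units=(32, 16)` can be checked by hand. The plain-Python sketch below does this. `mlp_param_count` is a hypothetical helper, and the size estimate assumes 4 bytes per float32 parameter.

```python
def mlp_param_count(input_size=1 * 28 * 28, hidden_units=(32, 16), num_classes=10):
    # layer widths 784 -> 32 -> 16 -> 10, mirroring MultiLayerPerceptron.__init__
    sizes = [input_size, *hidden_units, num_classes]
    # each Linear(n_in, n_out) has n_in * n_out weights and n_out biases
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))


n_params = mlp_param_count()
print(n_params)                          # 25818, shown as "25.8 K"
print(round(n_params * 4 / 1e6, 3))      # 0.103 (MB, assuming 4-byte float32 parameters)
```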




In [9]:
# now evaluate on the test data, loading the best checkpoint tracked by ModelCheckpoint (ckpt_path='best')
trainer.test(model=mnistclassifier, datamodule=mnist_dm, ckpt_path='best')

Restoring states from the checkpoint path at c:\Users\natha\Documents\pyro_pytorch_tutorials\pytorch_tb\ch13\lightning_logs\version_0\checkpoints\epoch=8-step=7740.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at c:\Users\natha\Documents\pyro_pytorch_tutorials\pytorch_tb\ch13\lightning_logs\version_0\checkpoints\epoch=8-step=7740.ckpt
c:\Users\natha\Documents\pyro_pytorch_tutorials\pyro_pytorch_venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'test_dataloader' to speed up the dataloader worker initialization.


Testing DataLoader 0: 100%|██████████| 157/157 [00:02<00:00, 74.39it/s]


[{'test_loss': 0.13700783252716064, 'test_acc': 0.9599000215530396}]
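With `ckpt_path='best'`, Lightning loads whichever checkpoint the `ModelCheckpoint` callback currently ranks highest. Under `save_top_k=1, mode='max', monitor='valid_acc'`, the bookkeeping works roughly like the sketch below. This is a simplified, hypothetical illustration (the `TopKTracker` name is made up), not Lightning's actual implementation, which also writes and deletes checkpoint files on disk.

```python
class TopKTracker:
    """Hypothetical, simplified sketch of save_top_k / mode bookkeeping (no file I/O)."""

    def __init__(self, k=1, mode="max"):
        self.k = k
        self.mode = mode
        self.kept = {}  # checkpoint name -> monitored score

    def _better(self, a, b):
        # strict comparison in this sketch: a tie does not displace a kept checkpoint
        return a > b if self.mode == "max" else a < b

    def observe(self, name, score):
        if len(self.kept) < self.k:
            self.kept[name] = score
            return
        # the weakest kept checkpoint is the lowest score in "max" mode, the highest in "min" mode
        pick_worst = min if self.mode == "max" else max
        worst = pick_worst(self.kept, key=self.kept.get)
        if self._better(score, self.kept[worst]):
            del self.kept[worst]      # Lightning would also delete this file from disk
            self.kept[name] = score   # ...and write the new checkpoint

    def best(self):
        pick_best = max if self.mode == "max" else min
        return pick_best(self.kept, key=self.kept.get)


# made-up validation accuracies, purely for illustration
tracker = TopKTracker(k=1, mode="max")
for epoch, acc in enumerate([0.91, 0.94, 0.93, 0.95, 0.95]):
    tracker.observe(f"epoch={epoch}", acc)
print(tracker.best())  # epoch=3
```

This kind of tracking is why `ckpt_path='best'` above restored the epoch=8 checkpoint rather than the one from the final epoch (epoch index 9).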

### Evaluating/Visualizing the model using TensorBoard

TensorBoard is a graphical tool for visualizing and analyzing model performance, such as loss and accuracy curves over the course of training. It can be launched from the command line and viewed in a browser, or embedded directly in a Jupyter notebook. We'll show how to use it here.
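TensorBoard reads event files from a log directory. With its default TensorBoard logger, Lightning writes these under `lightning_logs/version_N/`, and the metrics passed to `self.log(...)` end up there as scalar series. The minimal sketch below writes such a series directly with PyTorch's `torch.utils.tensorboard.SummaryWriter`. It needs the `tensorboard` package, and the temporary directory and loss values are made up for illustration.

```python
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter

log_dir = tempfile.mkdtemp()  # stand-in for lightning_logs/version_N/
writer = SummaryWriter(log_dir=log_dir)

# made-up loss values, purely to have something to plot
for step, loss in enumerate([0.9, 0.6, 0.4, 0.3]):
    writer.add_scalar("train_loss", loss, global_step=step)
writer.close()  # flush the event file to disk

event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
print(event_files)
# view the curve with: tensorboard --logdir <log_dir>
```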

In [None]:
# Start TensorBoard
# by default PL writes its logs to lightning_logs/, so we point TensorBoard there
%load_ext tensorboard
%tensorboard --logdir lightning_logs/
# NOTE: the inline TensorBoard view may not display in VS Code notebooks; in that case, run it from
# a terminal and open the printed URL in a browser, e.g. tensorboard --logdir .\pytorch_tb\ch13\lightning_logs\

### Resuming training from a checkpoint

In [13]:
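# here we load the weights from the best checkpoint of the first run (epoch 8) and train for 5 more epochs.
# note: load_from_checkpoint restores only the weights (and any saved hyperparameters); the optimizer
#    state and epoch counter start fresh, which is why the progress bar below counts epochs 0-4.
# note: reusing `callbacks` reuses the same ModelCheckpoint instance, which keeps its best-score
#    tracking and checkpoint directory (version_0) from the first run.
path = 'lightning_logs/version_0/checkpoints/epoch=8-step=7740.ckpt'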

model = MultiLayerPerceptron.load_from_checkpoint(path)

if torch.cuda.is_available(): # if you have GPUs
    trainer = pl.Trainer(
        max_epochs=5, callbacks=callbacks, devices=1
    )
else:
    trainer = pl.Trainer(
        max_epochs=5, callbacks=callbacks
    )

trainer.fit(model=model, datamodule=mnist_dm)


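# NOTE: the book's original (PL 1.x) code is kept below for reference. PL 2.0 removed the Trainer's
# resume_from_checkpoint and gpus arguments; a full resume is now done with trainer.fit(..., ckpt_path=path).
# The cell above instead follows https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html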
#if torch.cuda.is_available(): # if you have GPUs
#    trainer = pl.Trainer(
#        max_epochs=15, callbacks=callbacks, resume_from_checkpoint=path, gpus=1
#    )
#else:
#    trainer = pl.Trainer(
#        max_epochs=15, callbacks=callbacks, resume_from_checkpoint=path
#    )

#trainer.fit(model=mnistclassifier, datamodule=mnist_dm)

Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\natha\Documents\pyro_pytorch_tutorials\pyro_pytorch_venv\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:654: Checkpoint directory c:\Users\natha\Documents\pyro_pytorch_tutorials\pytorch_tb\ch13\lightning_logs\version_0\checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | train_acc | MulticlassAccuracy | 0      | train
1 | valid_acc | MulticlassAccuracy | 0      | train
2 | test_acc  | MulticlassAccuracy | 0      | train
3 | model     | Sequential         | 25.8 K | train
---------------------------------------------------------
25.8


c:\Users\natha\Documents\pyro_pytorch_tutorials\pyro_pytorch_venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.



c:\Users\natha\Documents\pyro_pytorch_tutorials\pyro_pytorch_venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 4: 100%|██████████| 860/860 [00:47<00:00, 17.98it/s, v_num=1, train_loss=0.0392, valid_loss=0.159, valid_acc=0.954] 

`Trainer.fit` stopped: `max_epochs=5` reached.




In [None]:
# re-visualize in tensorboard
%tensorboard --logdir lightning_logs/

In [14]:
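# evaluate the current weights (after the extra 5 epochs) on the test data;
# with a model passed and no ckpt_path, PL 2.x uses the model's weights as they are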
trainer.test(model=model, datamodule=mnist_dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\natha\Documents\pyro_pytorch_tutorials\pyro_pytorch_venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'test_dataloader' to speed up the dataloader worker initialization.


Testing DataLoader 0: 100%|██████████| 157/157 [00:01<00:00, 100.80it/s]


[{'test_loss': 0.1347116380929947, 'test_acc': 0.9603999853134155}]

In [15]:
# evaluate the 'best' checkpoint tracked by our ModelCheckpoint callback on the test data
# (the output shows it now comes from the second run: epoch=2-step=2580)
trainer.test(model=model, datamodule=mnist_dm, ckpt_path='best')

Restoring states from the checkpoint path at c:\Users\natha\Documents\pyro_pytorch_tutorials\pytorch_tb\ch13\lightning_logs\version_0\checkpoints\epoch=2-step=2580.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at c:\Users\natha\Documents\pyro_pytorch_tutorials\pytorch_tb\ch13\lightning_logs\version_0\checkpoints\epoch=2-step=2580.ckpt


Testing DataLoader 0: 100%|██████████| 157/157 [00:02<00:00, 70.32it/s]


[{'test_loss': 0.13509604334831238, 'test_acc': 0.960099995136261}]

In [17]:
# PL saves the model automatically for us, so we can easily load this module from a checkpoint for later re-use
path = "lightning_logs/version_0/checkpoints/epoch=2-step=2580.ckpt"
model2 = MultiLayerPerceptron.load_from_checkpoint(path)