# PyTorch Lightning

PyTorch has everything you need to train your models; however, there is much more to deep learning than stacking layers. The actual training requires a lot of boilerplate code, as we have already seen in our previous examples: transferring data from CPU to GPU, implementing the training loop, and so on. Additionally, scaling training or inference across multiple devices or machines often requires yet another set of integrations.

PyTorch Lightning provides APIs for organizing models, datasets, and training code. The idea is that Lightning leaves the research logic to you while automating the rest of the boilerplate. Features such as multi-GPU training, FP16 (mixed-precision) training, and training on TPUs are enabled through Trainer arguments, without changes to the model code.

More details about PyTorch Lightning can be found at https://www.pytorchlightning.ai/tutorials

We adopt PyTorch Lightning whenever we write code to train models in the rest of the chapters. This keeps the code succinct and precise.

In this context, let us revisit the problem of digit classification and show how it can be implemented within the Lightning framework. 

In [1]:
import torch
import torchmetrics
import pytorch_lightning as pl

from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from pytorch_lightning.callbacks import ModelCheckpoint

In [2]:
pl.seed_everything(42)

Global seed set to 42


42

## DataModule

A datamodule is a shareable, reusable class that encapsulates all the steps needed to process data. All datamodules must inherit from LightningDataModule, which provides the methods to be overridden.

In this specific case, we will implement MNIST as a datamodule. It can then be reused across multiple experiments with different models and architectures.

In [3]:
class MNISTDataModule(LightningDataModule):
    DATASET_DIR = "datasets"
    
    def __init__(self, transform=None, batch_size=100):
        super(MNISTDataModule, self).__init__()
        if transform is None:
            # Default transform
            transform = transforms.Compose([transforms.Resize((32, 32)),
                                            transforms.ToTensor()])
        self.transform = transform
        self.batch_size = batch_size

    
    def prepare_data(self):
        """
        All the steps needed to download, tokenize, and prepare the raw data should be done
        in prepare_data. We will download the MNIST dataset here.
        """
        # Download the train data
        datasets.MNIST(root=MNISTDataModule.DATASET_DIR, train=True, download=True)

        # Download the test data
        datasets.MNIST(root=MNISTDataModule.DATASET_DIR, train=False, download=True)
    
    def setup(self, stage=None):
        """
        The steps to set up the datasets are usually done in the setup method.
        """
        train_dataset = datasets.MNIST(root=MNISTDataModule.DATASET_DIR, train=True,
                                       download=False, transform=self.transform)
        # We will split the train dataset into train and validation sets.
        # All experiments are run using the train and val datasets
        self.train_dataset, self.val_dataset = random_split(train_dataset, [55000, 5000])
        self.test_dataset = datasets.MNIST(root=MNISTDataModule.DATASET_DIR, train=False,
                                           download=False, transform=self.transform)
    
    
    def train_dataloader(self):
        """
        As the name suggests, this method creates and returns the train dataloader.
        """
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=0)

    def val_dataloader(self):
        """
        As the name suggests, this method creates and returns the validation dataloader.
        """
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          shuffle=False, num_workers=0)

    def test_dataloader(self):
        """
        As the name suggests, this method creates and returns the test dataloader.
        """
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          shuffle=False, num_workers=0)
    
    @property
    def num_classes(self):
        return 10
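
`setup` splits the training set with `random_split`, which draws from the global random number generator. The split is reproducible here only because `pl.seed_everything(42)` was called earlier, and it can change if anything else consumes random numbers before `setup` runs. `random_split` also accepts a `generator` argument. The sketch below uses a stand-in dataset (a hypothetical `TensorDataset` of 60,000 indices in place of MNIST) to show that a dedicated, seeded generator yields the same split every time:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in with as many items as the MNIST train set (indices only, no images).
dataset = TensorDataset(torch.arange(60000))

def split(seed):
    # A dedicated generator makes the split depend only on `seed`,
    # not on how much of the global RNG was consumed before setup() ran.
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [55000, 5000], generator=generator)

train_a, val_a = split(42)
train_b, val_b = split(42)
print(len(train_a), len(val_a))
print(list(val_a.indices) == list(val_b.indices))
```

In the datamodule, this would amount to passing `generator=torch.Generator().manual_seed(42)` to the `random_split` call in `setup`.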

## LightningModule

A LightningModule organizes your PyTorch code into five sections:

1. Computations (init)
2. Train loop (training_step)
3. Validation loop (validation_step)
4. Test loop (test_step)
5. Optimizers (configure_optimizers)

Let us now see how we can define the LeNet classifier as a LightningModule.

In [4]:
class LeNetClassifier(LightningModule):
    def __init__(self, num_classes):
        """
        In __init__ we typically define the model, the criterion and any other setup steps needed to be done
        for the training of the model.
        """        
        super(LeNetClassifier, self).__init__()
        self.save_hyperparameters()
        
        self.conv1 = torch.nn.Sequential(
                        torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
                        torch.nn.Tanh(),
                        torch.nn.AvgPool2d(kernel_size=2))
        self.conv2 = torch.nn.Sequential(
                        torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
                        torch.nn.Tanh(),
                        torch.nn.AvgPool2d(kernel_size=2))
        self.conv3 = torch.nn.Sequential(
                        torch.nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1),
                        torch.nn.Tanh())
        self.fc1 = torch.nn.Sequential(
                        torch.nn.Linear(in_features=120, out_features=84),
                        torch.nn.Tanh())
        self.fc2 = torch.nn.Linear(in_features=84, out_features=num_classes)
        
        # We will use Cross Entropy Loss
        self.criterion = torch.nn.CrossEntropyLoss()
        
        self.accuracy = torchmetrics.Accuracy()

        
    def forward(self, X):
        """
        The forward method implements the forward pass of the model. In this case
        the input is a batch of images, and the output is the logits
        """
        # X is [batch_size, C, H, W] tensor
        conv_out = self.conv3(self.conv2(self.conv1(X)))
        batch_size = conv_out.shape[0]
        conv_out = conv_out.reshape(batch_size, -1)
        # Logits is [batch_size, num_classes] tensor
        logits = self.fc2(self.fc1(conv_out))
        return logits  
    
    def predict(self, X):
        """
        Predict runs the forward pass, performs softmax to convert the resulting logits into
        probabilities and returns the class with the highest probability.
        """
        logits = self.forward(X)
        probs = torch.softmax(logits, dim=1)
        return torch.argmax(probs, 1)
        
        
    def core_step(self, batch):
        """
        Both the training and validation loops involve the forward pass and the computation
        of loss and accuracy. Let us abstract them out and implement them in this method.
        """
        X, y_true = batch
        y_pred_logits = self.forward(X)
        loss = self.criterion(y_pred_logits, y_true)
        accuracy = self.accuracy(y_pred_logits, y_true)
        return loss, accuracy
        
    
    def training_step(self, batch, batch_idx):
        """
        This method implements the basic training step. We run the forward pass, compute the
        loss and accuracy, log the necessary values, and return the loss.
        """
        loss, accuracy = self.core_step(batch)
        if self.global_step % 100 == 0:
            self.log("train_loss", loss, on_step=True, on_epoch=True)
            self.log("train_accuracy", accuracy, on_step=True, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        """
        This method implements the basic validation step. We run the forward pass, compute
        the loss and accuracy, and return them.
        """
        return self.core_step(batch)
    
    def validation_epoch_end(self, outputs):
        """
        This method is called once at the end of every validation epoch. The outputs of all
        validation_step calls in that epoch are available via outputs.

        Here we compute the average validation loss and accuracy by averaging across all
        validation batches.
        """
        avg_loss = torch.stack([x[0] for x in outputs]).mean()
        avg_accuracy = torch.stack([x[1] for x in outputs]).mean()
        self.log("val_loss", avg_loss)
        self.log("val_accuracy", avg_accuracy)
        print(f"Epoch {self.current_epoch}, Val loss: {avg_loss:0.2f}, Accuracy: {avg_accuracy:0.2f}")
        return avg_loss
    
    def configure_optimizers(self):
        """
        The optimizer is configured in this method.
        """
        # Use self.parameters() rather than a global `model`, so the module is self-contained
        return torch.optim.SGD(self.parameters(), lr=0.01, momentum=0.9)
    
    def checkpoint_callback(self):
        """
        This callback determines the logic for how we want to checkpoint / save the model
        """
        # We will save the model with the best val accuracy.
        return ModelCheckpoint(monitor="val_accuracy", mode="max", save_top_k=1)

Notice how the model is independent of the data. This will allow us to potentially run the LeNetClassifier model on other data modules without any code change. 

Note that we do not need to do any of the following:
1. Moving the data to the device
2. Calling loss.backward()
3. Calling optimizer.step()
4. Setting model.train() / model.eval()
5. Resetting the gradients (optimizer.zero_grad())
6. Implementing the training loop

All of these are taken care of by PyTorch Lightning, eliminating a lot of boilerplate code.
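
For comparison, here is a minimal plain-PyTorch sketch that performs each of the steps listed above by hand. The function name `train_plain_pytorch`, the synthetic data, and the tiny linear model are hypothetical, chosen only so the sketch runs quickly without MNIST; it is not Lightning's internal implementation:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def train_plain_pytorch(model, train_loader, val_loader, epochs=1, lr=0.01):
    """Hypothetical helper: the work Lightning automates, written out by hand."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    history = []
    for epoch in range(epochs):                   # 6. the training loop itself
        model.train()                             # 4. training mode
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)     # 1. move the data to the device
            optimizer.zero_grad()                 # 5. reset the gradients
            loss = criterion(model(X), y)
            loss.backward()                       # 2. backpropagate
            optimizer.step()                      # 3. update the parameters
        model.eval()                              # 4. evaluation mode for validation
        total_loss, n = 0.0, 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(device), y.to(device)
                total_loss += criterion(model(X), y).item() * len(y)
                n += len(y)
        history.append(total_loss / n)            # sample-weighted mean validation loss
    return history

# Usage on synthetic data: the label is simply whether the first feature is positive
torch.manual_seed(0)
X = torch.randn(512, 20)
y = (X[:, 0] > 0).long()
train_loader = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=True)
val_loader = DataLoader(TensorDataset(X, y), batch_size=64)
history = train_plain_pytorch(torch.nn.Linear(20, 2), train_loader, val_loader, epochs=5)
print(history)
```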

## Trainer

Now we are ready to train our model. This can be done using the Trainer class.


This abstraction achieves the following:

1. You maintain control over all aspects via PyTorch code without an added abstraction.
2. The trainer uses best practices embedded by contributors and users from top AI labs such as Facebook AI Research, NYU, MIT, Stanford, etc…
3. The trainer allows overriding any key part that you don’t want automated.


In [5]:
dm = MNISTDataModule()
model = LeNetClassifier(num_classes=dm.num_classes)
exp_dir = "/tmp/mnist"
trainer = Trainer(
        default_root_dir=exp_dir, # The experiment directory
        callbacks=[model.checkpoint_callback()],
        gpus=torch.cuda.device_count(), # Number of GPUs to run on
        max_epochs=10,
        num_sanity_val_steps=0
    )

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [6]:
trainer.fit(model, dm)


  | Name      | Type             | Params
-----------------------------------------------
0 | conv1     | Sequential       | 156   
1 | conv2     | Sequential       | 2.4 K 
2 | conv3     | Sequential       | 48.1 K
3 | fc1       | Sequential       | 10.2 K
4 | fc2       | Linear           | 850   
5 | criterion | CrossEntropyLoss | 0     
6 | accuracy  | Accuracy         | 0     
-----------------------------------------------
61.7 K    Trainable params
0         Non-trainable params
61.7 K    Total params
0.247     Total estimated model params size (MB)

Epoch 0, Val loss: 0.24, Accuracy: 0.93

Epoch 1, Val loss: 0.14, Accuracy: 0.96

Epoch 2, Val loss: 0.10, Accuracy: 0.97

Epoch 3, Val loss: 0.08, Accuracy: 0.98

Epoch 4, Val loss: 0.07, Accuracy: 0.98

Epoch 5, Val loss: 0.06, Accuracy: 0.98

Epoch 6, Val loss: 0.06, Accuracy: 0.98

Epoch 7, Val loss: 0.06, Accuracy: 0.98

Epoch 8, Val loss: 0.05, Accuracy: 0.98

Epoch 9, Val loss: 0.05, Accuracy: 0.99
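
The validation numbers above are unweighted means of per-batch values, as computed in `validation_epoch_end`. That equals the dataset-level value here because the 5,000 validation images split into exactly 50 batches of 100. With a smaller final batch, the unweighted mean gives that batch too much weight. A small illustration with hypothetical per-batch numbers (not taken from the run above):

```python
import torch

# Hypothetical per-batch results: two full batches of 100 samples and a final batch of 10
batch_sizes = torch.tensor([100.0, 100.0, 10.0])
batch_acc = torch.tensor([0.90, 0.90, 0.50])

unweighted = batch_acc.mean()                                   # what validation_epoch_end computes
weighted = (batch_acc * batch_sizes).sum() / batch_sizes.sum()  # accuracy over all 210 samples

print(f"{unweighted:.4f} {weighted:.4f}")  # 0.7667 0.8810
```

When batch sizes can differ, weighting by batch size, or accumulating correct and total counts across batches, gives the true dataset-level accuracy.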


Note that we did not need to write the training loop either; calling trainer.fit is enough to train the model.

Additionally, the values logged with self.log are written to TensorBoard by default, so we can inspect the loss and accuracy curves by running `tensorboard --logdir /tmp/mnist`.

In [7]:
# Inference follows a similar pattern as before. 

model.eval()
X, y_true = next(iter(dm.test_dataloader()))
with torch.no_grad():
    y_pred = model.predict(X)
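
`predict` applies softmax before `argmax`. Because softmax is strictly increasing within each row, taking `argmax` of the raw logits picks the same classes. The sketch below checks this and computes a batch accuracy, using random logits and labels as a stand-in for `model(X)` and `y_true` so it runs without the trained model:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(100, 10)          # stand-in for model(X) on a batch of 100 images
y_true = torch.randint(0, 10, (100,))  # stand-in labels

via_softmax = torch.argmax(torch.softmax(logits, dim=1), dim=1)  # what predict() does
via_logits = torch.argmax(logits, dim=1)                         # same classes, one fewer op

print(torch.equal(via_softmax, via_logits))  # True
batch_accuracy = (via_logits == y_true).float().mean().item()
print(f"batch accuracy: {batch_accuracy:.2f}")
```

With the trained model, the accuracy of the batch above is `(y_pred == y_true).float().mean()`.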