<div style="line-height:1.2;">

<h1 style="color:#BF66F2; margin-bottom: 0.3em;"> Convolutional Neural Networks in PyTorch 6 </h1>

<h4 style="margin-top: 0.3em; margin-bottom: 1em;"> PyTorch Lightning tutorial with CNN trained on MNIST dataset. Focus on EarlyStopping. </h4>

<div style="line-height:1.4; margin-bottom: 0.5em;">
    <h3 style="color: lightblue; display: inline; margin-right: 0.5em;">Keywords:</h3> 
    pytorch_lightning.callbacks + cross_entropy + deepspeed
</div>

</div>

<h3 style="color:#BF66F2 "> Recap: </h3>
<div style="margin-top: -22px;">
PyTorch Lightning is a lightweight PyTorch wrapper that simplifies the training and deployment of deep learning models (including CNNs) by providing a high-level interface and a set of abstractions. <br>
Its goal is to make training CNNs more concise, modular, and reproducible, so that users can focus on the core model architecture and training logic.
</div>

In [1]:
%%script echo Skipping, since already installed
!pip install torchmetrics

Skipping, since already installed


In [2]:
%%script echo Skipping, since already installed
!pip install pytorch_lightning

Skipping, since already installed


In [3]:
import torch
import torchmetrics
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from tqdm import tqdm

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback, EarlyStopping

In [4]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # hide TensorFlow C++ INFO/WARNING logs (e.g. CUDA messages if TensorFlow gets imported); PyTorch is unaffected

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [6]:
class CNNLightning(pl.LightningModule):
    """ PyTorch Lightning module for a vanilla CNN model.

    Args:
        - Learning rate for the optimizer  [float default: 3e-4]
        - Number of input channels [int default: 1]
        - Number of output classes [int default: 10]

    Details:
        - 'pl.LightningModule' is the base class for Lightning modules.
        - 'torchmetrics.Accuracy' is used to compute accuracy metric.
        - 'nn.Conv2d' is a 2D convolutional layer.
        - 'nn.MaxPool2d' is a max pooling layer.
        - 'nn.Linear' is a fully connected layer.
    """
    def __init__(self, lr=3e-4, in_channels=1, num_classes=10):
        super().__init__()
        self.lr = lr
        # One metric object per stage, so training, validation and test states never mix
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=8,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=8,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # Each pooling step halves a 28x28 input: 28 -> 14 -> 7, with 16 channels
        self.fc1 = nn.Linear(16 * 7 * 7, num_classes)

    def training_step(self, batch, batch_idx):
        """ Perform a single training step.

        Parameters:
            - batch: Tuple containing input data and labels
            - batch_idx: Index of the current batch [int]

        Details:
            - '_common_step' is a helper method that computes the forward pass.
            - 'F.cross_entropy' is the loss function used for computing the loss.
            - 'self.log' logs the training accuracy.

        Returns:
            Computed loss value for the batch
        """
        x, y = batch
        y_hat = self._common_step(x, batch_idx)
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(y_hat, y)  # update the metric state with this batch
        self.log(
            "train_acc_step",
            self.train_acc,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
        )

        return loss

    def on_train_epoch_end(self):
        """ Execute at the end of each training epoch. """
        # Reset the training accuracy metric
        self.train_acc.reset()

    def test_step(self, batch, batch_idx):
        """ Perform a single test step.

        Parameters:
            batch: A tuple containing input data and labels.
            batch_idx (int): The index of the current batch.

        Details:
            - '_common_step' is a helper method that computes the forward pass.
            - 'F.cross_entropy' is the loss function used for computing the loss.
            - 'self.log' logs the test loss and accuracy.
        """
        x, y = batch
        y_hat = self._common_step(x, batch_idx)
        loss = F.cross_entropy(y_hat, y)
        accuracy = self.test_acc(y_hat, y)
        self.log("test_loss", loss, on_step=True)
        self.log("test_acc", accuracy, on_step=True)

    def validation_step(self, batch, batch_idx):
        """ Perform a single validation step.

        Parameters:
            - batch: Tuple containing input data and labels
            - batch_idx: Index of the current batch [int]

        Details:
            - '_common_step' is a helper method that computes the forward pass.
            - 'F.cross_entropy' is the loss function used for computing the loss.
            - 'self.log' logs the validation loss and accuracy.
        """
        x, y = batch
        y_hat = self._common_step(x, batch_idx)
        loss = F.cross_entropy(y_hat, y)
        accuracy = self.val_acc(y_hat, y)
        self.log("val_loss", loss, on_step=True)
        self.log("val_acc", accuracy, on_step=True)

    def predict_step(self, batch, batch_idx):
        """ Perform a single prediction step.

        Parameters:
            - batch: Tuple containing input data and labels.
            - batch_idx: Index of the current batch [int]

        Details:
            - '_common_step' is a helper method that computes the forward pass.

        Returns:
            Predicted output for the batch (y_hat)

        """
        x, _ = batch  # labels are not needed for prediction
        y_hat = self._common_step(x, batch_idx)
        return y_hat

    def _common_step(self, x, batch_idx):
        """ Perform the common steps of the forward pass.

        Parameters:
            - Input data
            - Index of the current batch [int]

        Details:
            - 'self.pool' applies max pooling to the input.
            - 'F.relu' applies the ReLU activation function.
            - 'x.reshape' reshapes the tensor.

        Returns:
            The predicted output (y_hat)
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.reshape(x.shape[0], -1)
        y_hat = self.fc1(x)
        return y_hat

    def configure_optimizers(self):
        """ Define Adam as optimizer for the model. """
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
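The `16 * 7 * 7` input size of `fc1` follows from the standard output-size formula shared by convolution and pooling layers, out = floor((n + 2 * padding - kernel) / stride) + 1 (ignoring dilation, which defaults to 1). A small pure-Python check, assuming 28x28 MNIST inputs; the helper `conv_out` is hypothetical and not part of PyTorch:

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d/MaxPool2d layer along one dimension."""
    return (n + 2 * padding - kernel) // stride + 1


size = 28                                             # MNIST images are 28x28
size = conv_out(size, kernel=3, stride=1, padding=1)  # conv1 keeps 28
size = conv_out(size, kernel=2, stride=2)             # pool -> 14
size = conv_out(size, kernel=3, stride=1, padding=1)  # conv2 keeps 14
size = conv_out(size, kernel=2, stride=2)             # pool -> 7

out_channels = 16                                     # channels produced by conv2
flattened = out_channels * size * size
print(size, flattened)  # 7 784
```

If the input resolution or the number of pooling layers changes, recomputing `flattened` this way gives the new `in_features` for the linear layer.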

In [7]:
class MNISTDataModule(pl.LightningDataModule):
    """ MNISTDataModule is a PyTorch Lightning data module for loading and preprocessing the MNIST dataset.

    Args:
        Batch size for the data loaders (default: 512) [int]

    Details:
        - 'pl.LightningDataModule' is the base class for Lightning data modules.
        - 'datasets.MNIST' is the MNIST dataset class from torchvision.
        - 'transforms.ToTensor()' converts the image data to tensors.
        - 'torch.utils.data.random_split' is used to split the dataset into training and validation sets.
        - 'DataLoader' is used to create data loaders for the datasets.
    """
    def __init__(self, batch_size=512):
        """ Initialize the Data Module """
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage):
        """ Set up the MNIST dataset for training, validation, and testing.

        Parameters:
            stage: Current stage (e.g., "fit", "validate", "test").

        Details:
            'torch.utils.data.random_split' splits the 60,000 MNIST training images into
            55,000 for training and 5,000 for validation; the seeded generator keeps the
            split identical across runs.
        """
        mnist_full = datasets.MNIST(
            root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
        self.mnist_test = datasets.MNIST(
            root="dataset/", train=False, transform=transforms.ToTensor(), download=True)
        self.mnist_train, self.mnist_val = torch.utils.data.random_split(
            mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42))

    def train_dataloader(self):
        """ Create the DataLoader object for training. """
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=6,
            shuffle=True,)

    def val_dataloader(self):
        """  Create the DataLoader object for validation. """
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=2, shuffle=False)

    def test_dataloader(self):
        """ Create the DataLoader object for the test set. """
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=2, shuffle=False)
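Without an explicit `generator`, `random_split` draws from PyTorch's global random state, so the train/validation split can differ between runs unless a global seed is set. A small sketch of the seeded variant on a dummy list of 100 indices instead of MNIST; the helper `split_indices` is hypothetical:

```python
import torch
from torch.utils.data import random_split


def split_indices(n_items, lengths, seed):
    """Split range(n_items) with a seeded generator and return the index lists."""
    generator = torch.Generator().manual_seed(seed)
    subsets = random_split(list(range(n_items)), lengths, generator=generator)
    # Each element is a torch.utils.data.Subset; .indices holds the chosen positions
    return [list(subset.indices) for subset in subsets]


first = split_indices(100, [90, 10], seed=42)
second = split_indices(100, [90, 10], seed=42)
print(first == second)               # True: same seed, same split
print(len(first[0]), len(first[1]))  # 90 10
```

Calling `pl.seed_everything(42)` before training is a broader alternative that seeds PyTorch, NumPy, and Python's `random` at once.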

In [8]:
class MyPrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is starting")
    def on_train_end(self, trainer, pl_module):
        print("Training is ending")
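Lightning calls `on_train_epoch_end` once per epoch but `on_train_end` only once, after fitting finishes, which is why the "ending" message belongs in the latter. A toy loop makes the call frequency explicit; it is a hypothetical stand-in, not Lightning's `Trainer`, and its hooks drop the real `(trainer, pl_module)` arguments for brevity:

```python
class RecordingCallback:
    """Toy stand-in for a Lightning Callback that records which hooks fire."""

    def __init__(self):
        self.calls = []

    def on_train_start(self):
        self.calls.append("on_train_start")

    def on_train_epoch_end(self):
        self.calls.append("on_train_epoch_end")

    def on_train_end(self):
        self.calls.append("on_train_end")


def toy_fit(callback, max_epochs):
    """Hypothetical training loop calling the three hooks in the order Lightning does."""
    callback.on_train_start()
    for _ in range(max_epochs):
        # ... training and validation batches would run here ...
        callback.on_train_epoch_end()
    callback.on_train_end()


cb = RecordingCallback()
toy_fit(cb, max_epochs=3)
print(cb.calls.count("on_train_epoch_end"), cb.calls.count("on_train_end"))  # 3 1
```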

In [9]:
precision = "medium"
torch.set_float32_matmul_precision(precision)
criterion = nn.CrossEntropyLoss()

<h3 style="color:#BF66F2 ">  Example 1 </h3>

In [10]:
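# Define the Trainer with the necessary arguments.
# N.B. Recent Lightning versions (2.0+) removed the `gpus` argument, so the old call below raises
# TypeError: Trainer.__init__() got an unexpected keyword argument 'gpus'
# trainer = pl.Trainer(
#     max_epochs = 10,
#     gpus = 1 if torch.cuda.is_available() else 0,
#     callbacks=[MyPrintingCallback()],)
# Use `accelerator` (and optionally `devices`) instead, as below.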
# Instantiate the LightningModule and LightningDataModule
model = CNNLightning()
data_module = MNISTDataModule()

trainer = pl.Trainer(max_epochs=20, accelerator="auto", callbacks=[MyPrintingCallback()],)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


<div style="line-height:0.64">
<h3 style="color:#BF66F2 ">  Example 2 </h3>
</div>
This example must run on a GPU device: the Trainer below requests `accelerator="gpu"`, so on a CPU-only machine it raises "MisconfigurationException: No supported gpu backend found!" <br>
from "pytorch_lightning/trainer/connectors/accelerator_connector.py".
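One way to keep the notebook runnable on a CPU-only machine is to choose the Trainer arguments from the available hardware instead of hard-coding `accelerator="gpu"`. A minimal sketch; the helper `trainer_kwargs` is hypothetical, and the precision strings assume Lightning 2.x naming (`"16-mixed"`, `"32-true"`):

```python
def trainer_kwargs(cuda_available):
    """Pick Trainer arguments that work with or without a CUDA GPU."""
    if cuda_available:
        # Same settings as the GPU run in this notebook
        return {"accelerator": "gpu", "devices": [0], "precision": "16-mixed"}
    # CPU fallback: full 32-bit precision on a single process
    return {"accelerator": "cpu", "devices": 1, "precision": "32-true"}


print(trainer_kwargs(True)["accelerator"])   # gpu
print(trainer_kwargs(False)["accelerator"])  # cpu
```

Usage would look like `pl.Trainer(max_epochs=5, **trainer_kwargs(torch.cuda.is_available()))`. Alternatively, `accelerator="auto"` (as in Example 1) lets Lightning pick the device itself.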

In [11]:
%%script echo Skipping, since already installed
# To use DeepSpeedStrategy
!pip install -U deepspeed

Skipping, since already installed


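# Training strategies:
- Training with the DistributedDataParallel strategy on many GPUs:<br>
`strategy="ddp"`

- Training with the DistributedDataParallel strategy on many GPUs, with options configured (requires `from pytorch_lightning.strategies import DDPStrategy`):<br>
`strategy=DDPStrategy(static_graph=True)`

- Training with the DDP Spawn strategy: <br>
`strategy="ddp_spawn"`

- Training with the DeepSpeed strategy on available GPUs (requires the `deepspeed` package installed above): <br>
`strategy="deepspeed"`

=> In notebooks (Colab, Jupyter) use `strategy="ddp_notebook"`, since `"ddp"` and `"ddp_spawn"` cannot be launched from an interactive session.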
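The choice among these strategies mostly depends on the number of devices and on whether the code runs as a script or inside a notebook. A minimal sketch of that decision, assuming Lightning 2.x strategy names; `pick_strategy` and `in_notebook` are hypothetical helpers, and the `ipykernel` check is only a heuristic:

```python
import sys


def in_notebook():
    """Heuristic: a Jupyter/Colab kernel has the 'ipykernel' module loaded."""
    return "ipykernel" in sys.modules


def pick_strategy(num_devices, interactive):
    """Choose a Lightning strategy string for the given setup."""
    if num_devices <= 1:
        return "auto"          # single device: let Lightning decide
    if interactive:
        return "ddp_notebook"  # forks worker processes, usable inside a notebook
    return "ddp"               # scripts: the script is re-launched for the other ranks


print(pick_strategy(2, interactive=in_notebook()))
```

For example, `pl.Trainer(accelerator="gpu", devices=2, strategy=pick_strategy(2, in_notebook()))`.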

In [12]:
%%script echo Skipping, since already installed
!pip install "ray[tune]"

Skipping, since already installed


In [13]:
%%script echo Skipping, since already installed
!pip install "pytorch-lightning-bolts>=0.2.5"

Skipping, since already installed


In [14]:
import os
import torch
from torch.nn import functional as F
import pytorch_lightning as pl
from ray.tune.integration.pytorch_lightning import TuneReportCallback

In [1]:
%%script echo Skipping, since useless
""" Use 20% of training data for validation. """
train_set_size = int(len(train_dataset) * 0.8)
valid_set_size = len(train_dataset) - train_set_size

""" Split the train set into two """
seed = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    train_dataset, [train_set_size, valid_set_size], generator=seed)

Skipping, since useless


In [15]:
# Initialize network
model_lightning = CNNLightning()

trainer = pl.Trainer(
    #fast_dev_run=True,
    # overfit_batches=3,
    max_epochs=5,
    precision='16-mixed',           # mixed precision; the bare value 16 still works for historical reasons but is discouraged
    accelerator="gpu",
    #devices=[0,1],                 # use this instead when two GPUs are available
    devices=[0],
    callbacks=[EarlyStopping(monitor="val_loss", mode="min")],
    enable_model_summary=True,
    profiler="simple",
    strategy="ddp_notebook",
)

dm = MNISTDataModule()
trainer.fit(model=model_lightning, datamodule=dm,)

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:lightning_fabric.utilities.distributed:Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
INFO:pytorch_lightning.utilities.rank_zero:----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type     

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.
INFO:pytorch_lightning.profilers.profiler:FIT Profiler Report

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                         	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                              
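The run above stopped because `max_epochs=5` was reached, not because of `EarlyStopping`. With the arguments used here (`monitor="val_loss"`, `mode="min"`) and the callback's defaults (`patience=3`, `min_delta=0.0`), training stops only after `val_loss` fails to beat its best value by more than `min_delta` for `patience` consecutive validation checks (once per epoch with default settings). A pure-Python sketch of that rule, written for illustration; `early_stop_epoch` is hypothetical and ignores other `EarlyStopping` options such as `check_finite` and `stopping_threshold`:

```python
def early_stop_epoch(scores, patience=3, min_delta=0.0, mode="min"):
    """Return the 0-based check index at which training would stop, or None."""
    best = float("inf") if mode == "min" else float("-inf")
    wait = 0  # consecutive checks without improvement
    for epoch, score in enumerate(scores):
        if mode == "min":
            improved = score < best - min_delta
        else:
            improved = score > best + min_delta
        if improved:
            best = score
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Validation losses that keep improving: training runs to max_epochs
print(early_stop_epoch([0.9, 0.5, 0.4, 0.35, 0.3]))         # None
# Three checks in a row without a new best: stop at index 4
print(early_stop_epoch([0.9, 0.5, 0.51, 0.52, 0.50, 0.2]))  # 4
```

Since `patience` and `min_delta` are constructor arguments of `EarlyStopping`, a larger `max_epochs` together with, say, `min_delta=1e-3` would give the callback room to actually trigger.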

In [16]:
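# Tune the trainer first to find the best batch size and lr.
# N.B. Lightning 2.x removed trainer.tune(); the finders now live in pytorch_lightning.tuner.Tuner,
# e.g. Tuner(trainer).lr_find(model_lightning, datamodule=dm)
#trainer.tune(model_lightning, dm)

# The model was already trained by trainer.fit() in the previous cell.
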
# Test model on test loader from LightningDataModule
trainer.test(model=model_lightning, datamodule=dm)

INFO:lightning_fabric.utilities.distributed:Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
INFO:pytorch_lightning.utilities.rank_zero:----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

INFO:pytorch_lightning.profilers.profiler:TEST Profiler Report

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                   	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                    	|  -              	|  36

[{'test_loss_epoch': 0.2383224368095398, 'test_acc_epoch': 0.9337999820709229}]
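A `test_acc_epoch` of about 0.934 means that, averaged over the test batches, roughly 93.4% of the predictions were correct. Plain (micro-averaged) accuracy counts a prediction as correct when the index of the highest-scoring logit equals the label. A dependency-free sketch of that computation; `multiclass_accuracy` is a hypothetical helper, not the torchmetrics implementation:

```python
def multiclass_accuracy(logits, labels):
    """Fraction of samples whose arg-max logit equals the true label."""
    correct = 0
    for row, label in zip(logits, labels):
        predicted = max(range(len(row)), key=lambda k: row[k])  # arg-max index
        correct += int(predicted == label)
    return correct / len(labels)


logits = [
    [2.0, 0.1, -1.0],  # predicts class 0
    [0.3, 0.2, 1.5],   # predicts class 2
    [0.0, 3.0, 0.5],   # predicts class 1
    [1.0, 0.9, 0.8],   # predicts class 0
]
labels = [0, 2, 1, 1]
print(multiclass_accuracy(logits, labels))  # 0.75
```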