## PyTorch Lightning for MNIST

This notebook uses PyTorch Lightning to easily train and test a model on the MNIST dataset.

In [1]:
! pip install lightning torchmetrics --upgrade
! pip install torch torchvision --upgrade

Collecting lightning
  Downloading lightning-2.5.5-py3-none-any.whl.metadata (39 kB)
Collecting torchmetrics
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Downloading lightning-2.5.5-py3-none-any.whl (828 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m828.5/828.5 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchmetrics, lightning
  Attempting uninstall: torchmetrics
    Found existing installation: torchmetrics 1.7.3
    Uninstalling torchmetrics-1.7.3:
      Successfully uninstalled torchmetrics-1.7.3
  Attempting uninstall: lightning
    Found existing installation: lightning 2.4.0
    Uninstalling lightning-2.4.0:
      Successfully uninstalled lightning-2.4.0
Successfully installed lightning-2.5.5 torchmet

### DataModule for MNIST



In [8]:
import torch
import torchvision
import lightning as L
import torch.nn as nn
from torchvision import transforms

# create an MNIST datamodule

class MNISTDataModule(L.LightningDataModule):
    """LightningDataModule serving the MNIST train and test splits.

    Args:
        batch_size: Number of samples per batch for both dataloaders.
        data_dir: Directory the MNIST files are downloaded to and read from.
        num_workers: Worker processes per DataLoader (0 loads in the main process).
    """

    def __init__(self, batch_size: int = 32, data_dir: str = "./data", num_workers: int = 0):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers
        # One shared transform: both splits are converted to tensors the same way.
        self.transform = transforms.ToTensor()

    def prepare_data(self):
        """Fetch the MNIST files into ``data_dir`` (no dataset is kept on ``self``)."""
        torchvision.datasets.MNIST(root=self.data_dir, download=True)

    def setup(self, stage=None):
        """Create the datasets required for ``stage``.

        ``stage=None`` builds both splits; ``"fit"`` builds only the training
        split and ``"test"`` only the test split. Other stages build nothing.
        """
        if stage in ("fit", None):
            self.mnist_train = self._load_split(train=True)
        if stage in ("test", None):
            self.mnist_test = self._load_split(train=False)

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training split."""
        return self._make_loader(self.mnist_train, shuffle=True)

    def test_dataloader(self):
        """Return an in-order DataLoader over the test split."""
        return self._make_loader(self.mnist_test, shuffle=False)

    def _load_split(self, train: bool):
        """Build the MNIST dataset for the requested split."""
        # download=True is kept so that calling setup() directly, without a
        # prior prepare_data(), still works.
        return torchvision.datasets.MNIST(
            root=self.data_dir, train=train, download=True, transform=self.transform
        )

    def _make_loader(self, dataset, shuffle: bool):
        """Wrap ``dataset`` in a DataLoader using this module's batch settings."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )


### Model for MNIST

In [9]:
# Small CNN: two conv/ReLU/max-pool stages followed by a two-layer fully connected classifier
class CNN(nn.Module):
    """Convolutional classifier: two conv/ReLU/max-pool stages and a two-layer head.

    Args:
        num_classes: Size of the output layer (number of logits per sample).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        first_stage = [
            nn.Conv2d(1, 16, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        second_stage = [
            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        self.features = nn.Sequential(*first_stage, *second_stage)

        # 800 input features = 32 channels x 5 x 5, which is what the feature
        # extractor above yields for a 28x28 single-channel image.
        head = [
            nn.Linear(800, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Run the feature extractor, flatten per sample, and return class logits."""
        feature_maps = self.features(x)
        flattened = feature_maps.flatten(start_dim=1)
        return self.classifier(flattened)

### LightningModule for MNIST

Notice that this module wraps the simple `CNN` model defined above by default.

In [10]:
# build a pytorch lightning model
from torchmetrics import Accuracy

class LitModel(L.LightningModule):
    """LightningModule that trains and tests a 10-class classifier.

    Args:
        model: Network mapping an input batch to class logits. When omitted,
            a new ``CNN()`` is created for this instance.
    """

    def __init__(self, model=None):
        super().__init__()
        # Create the default network here rather than in the signature: a
        # default argument is evaluated once, so every LitModel() would
        # otherwise share (and keep training) the very same CNN instance.
        self.model = model if model is not None else CNN()
        self.loss_fn = nn.CrossEntropyLoss()
        self.accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one batch and log its epoch mean."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_train_epoch_end(self):
        """Log the learning rate of the first optimizer's first parameter group."""
        lr = self.trainer.optimizers[0].param_groups[0]["lr"]
        self.log("learning_rate", lr, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        """Return Adam (lr=1e-3) with a per-epoch linear decay schedule."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

        def linear_decay(epoch):
            max_epochs = self.trainer.max_epochs
            # Without a positive epoch budget (e.g. max_epochs=-1) the linear
            # formula would increase the rate, so keep it constant instead.
            if max_epochs is None or max_epochs <= 0:
                return 1.0
            return max(0.0, 1 - epoch / max_epochs)

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_decay)
        return [optimizer], [scheduler]

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch and return ``(loss, batch_accuracy)``.

        Epoch-level ``test_loss`` and ``test_acc`` are produced by the logging
        calls below rather than by averaging per-batch values by hand: the
        final batch can be smaller than the others, so giving every batch the
        same weight would skew both numbers.
        """
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        # The forward call returns this batch's accuracy and also accumulates
        # the metric's running state used for the epoch-level value.
        acc = self.accuracy(y_hat, y)
        self.log(
            "test_loss", loss,
            on_step=False, on_epoch=True, prog_bar=True, logger=True,
            batch_size=x.size(0),  # weight the epoch mean by samples per batch
        )
        # Logging the metric object lets Lightning compute it over all samples
        # at epoch end and reset it afterwards.
        self.log("test_acc", self.accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss, acc

### Training and Testing the Model

Note how we use the `Trainer` class to train and test the model. Also note that `accelerator="auto"` lets Lightning pick the available device; precision is not set here, so the `Trainer` default applies.

In [11]:
# train for 10 epochs, then evaluate on the test split

# Seed Python, NumPy and torch so weight initialisation and batch shuffling
# are repeatable across Restart & Run All.
L.seed_everything(42, workers=True)

dm = MNISTDataModule()
model = LitModel()
trainer = L.Trainer(max_epochs=10, accelerator="auto")
# Pass the datamodule by keyword so it is not mistaken for a dataloader argument.
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)


💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name     | Type               | Params | Mode 
--------------------------------------------------------
0 | model    | CNN                | 108 K  | train
1 | loss_fn  | CrossEntropyLoss   | 0      | train
2 | accuracy | MulticlassAccuracy | 0      | train
--------------------------------------------------------
108 K     Trainable params
0         Non-trainable params
108 K     Total params
0.434     Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode


Epoch 9: 100%|██████████| 1875/1875 [00:33<00:00, 56.12it/s, v_num=1, train_loss=0.00237, learning_rate=0.0001]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 1875/1875 [00:33<00:00, 56.09it/s, v_num=1, train_loss=0.00237, learning_rate=0.0001]
Testing DataLoader 0: 100%|██████████| 313/313 [00:03<00:00, 100.85it/s]


[{'test_loss': 0.030096787959337234, 'test_acc': 0.9915135502815247}]