# 🚀 Getting started with PyTorch Lightning

In this notebook, we’ll go over the basics of PyTorch Lightning by building and training a model on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).

> This notebook is heavily inspired by the original material provided by lightning.ai on [Introduction to PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/mnist-hello-world.html) and [Data Modules](https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/datamodules.html).

In [None]:
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, Tuple

import lightning as L
import torch
from lightning import LightningModule
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.classification.accuracy import Accuracy
from torchvision import transforms as T
from torchvision.datasets import MNIST

In [None]:
# Use the larger batch size only when a CUDA device is available.
BATCH_SIZE = 64
if torch.cuda.is_available():
    BATCH_SIZE = 256

# All artefacts live in sibling folders of the directory the notebook runs from.
BASE_DIR = Path.cwd().resolve().parent
MODEL_DIR, LOG_DIR, MNIST_DIR = (
    BASE_DIR / folder for folder in ("models", "logs", "data")
)

`DataModules` decouple data-related hooks from `LightningModule`, enabling dataset-agnostic models.

- `__init__`: Stores `data_dir` and the loader settings, defines the default transforms shared by all splits (plus any optional extra transforms), and sets `self.dims`.  
- `prepare_data`: Downloads the dataset (if needed) without making state assignments.  
- `setup`: Loads data, prepares PyTorch datasets, and handles logic for ‘fit’ and ‘test’ stages. Runs safely across GPUs.  
- Dataloaders: `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` return `DataLoader` instances wrapping the prepared datasets.

In [None]:
class MNISTDataModule(L.LightningDataModule):
    """Data module that downloads MNIST and serves train/val/test dataloaders.

    The official training set is divided into a training and a validation
    subset using a fixed-seed permutation, so repeated ``setup`` calls (and
    every process of a multi-device run) obtain the same split. The extra
    ``transforms`` are treated as training-time augmentation only: validation
    and test samples always go through the deterministic default pipeline
    (``ToTensor`` followed by ``Normalize``).
    """

    def __init__(
        self,
        data_dir: Path,
        num_workers: int = 1,
        pin_memory: bool = True,
        batch_size: int = 16,
        transforms: list | None = None,
        val_size: int = 5000,
        split_seed: int = 42,
    ):
        """Store the configuration; no data is touched here.

        Args:
            data_dir: Directory where MNIST is downloaded to / read from.
            num_workers: Worker processes used by every dataloader.
            pin_memory: Forwarded to every dataloader.
            batch_size: Batch size used by every dataloader.
            transforms: Optional extra transforms applied *before* the default
                ones, to training samples only.
            val_size: Number of training-set samples held out for validation.
            split_seed: Seed of the private generator that draws the
                train/validation split.
        """
        super().__init__()

        self.save_hyperparameters()

        self.data_dir = data_dir
        self.dims = (1, 28, 28)
        self.num_classes = 10
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.batch_size = batch_size
        self.val_size = val_size
        self.split_seed = split_seed
        # Copy so later mutation of the caller's list cannot alter the pipeline.
        self.transforms = list(transforms) if transforms else []
        self._default_transforms = [
            T.ToTensor(),
            T.Normalize((0.1307,), (0.3081,)),
        ]

    def prepare_data(self):
        """Download both MNIST splits if they are missing; assigns no state."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets required by ``stage``.

        Args:
            stage: ``"fit"``/``"validate"`` build the train and validation
                subsets, ``"test"`` builds the test set, ``None`` builds all.

        Raises:
            ValueError: If ``val_size`` does not leave at least one sample in
                both the training and the validation subset.
        """
        if stage in ("fit", "validate", None):
            train_transform = T.Compose(self.transforms + self._default_transforms)
            eval_transform = T.Compose(self._default_transforms)

            # Two views over the same images: augmented for training, plain for
            # validation, so validation metrics are not distorted by random
            # augmentation.
            train_view = MNIST(self.data_dir, train=True, transform=train_transform)
            val_view = MNIST(self.data_dir, train=True, transform=eval_transform)

            total = len(train_view)
            if not 0 < self.val_size < total:
                raise ValueError(
                    f"val_size must be in (0, {total}), got {self.val_size}"
                )

            # A private, seeded generator keeps the split identical across
            # setup calls and independent of the global RNG state.
            generator = torch.Generator().manual_seed(self.split_seed)
            order = torch.randperm(total, generator=generator).tolist()
            cut = total - self.val_size
            self.mnist_train = torch.utils.data.Subset(train_view, order[:cut])
            self.mnist_val = torch.utils.data.Subset(val_view, order[cut:])

        if stage in ("test", None):
            self.mnist_test = MNIST(
                self.data_dir,
                train=False,
                transform=T.Compose(self._default_transforms),
            )

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using the module-wide settings."""
        return DataLoader(
            dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            batch_size=self.batch_size,
        )

    def train_dataloader(self):
        """Return the shuffled dataloader over the training subset."""
        return self._make_loader(self.mnist_train, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation subset."""
        return self._make_loader(self.mnist_val, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled dataloader over the MNIST test set."""
        return self._make_loader(self.mnist_test, shuffle=False)

In [None]:
class MNISTLitModule(LightningModule):
    """Example of a `LightningModule` for MNIST classification.

    A four-layer fully connected network (Linear -> BatchNorm -> ReLU three
    times, then a Linear output layer) trained with cross-entropy loss. Loss
    and accuracy are tracked per stage with torchmetrics objects and logged
    once per epoch.
    """

    def __init__(
        self,
        optimizer: Callable[..., torch.optim.Optimizer],
        scheduler: Callable[..., torch.optim.lr_scheduler._LRScheduler] | None,
        input_size: int = 784,
        lin1_size: int = 256,
        lin2_size: int = 256,
        lin3_size: int = 256,
        output_size: int = 10,
        compile: bool = True,
    ) -> None:
        """Initializes the MNISTLitModule.

        Args:
            optimizer: Factory (optimizer class or ``functools.partial``) that
                is called as ``optimizer(params=...)`` in
                ``configure_optimizers``.
            scheduler: Factory called as ``scheduler(optimizer=...)`` to build
                the learning-rate scheduler, or ``None`` for no scheduler.
            input_size: Number of features of a flattened input sample.
            lin1_size: Width of the first hidden layer.
            lin2_size: Width of the second hidden layer.
            lin3_size: Width of the third hidden layer.
            output_size: Number of classes / output logits.
            compile: Whether ``self.net`` is wrapped with ``torch.compile``
                when fitting starts.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.net = nn.Sequential(
            nn.Linear(input_size, lin1_size),
            nn.BatchNorm1d(lin1_size),
            nn.ReLU(),
            nn.Linear(lin1_size, lin2_size),
            nn.BatchNorm1d(lin2_size),
            nn.ReLU(),
            nn.Linear(lin2_size, lin3_size),
            nn.BatchNorm1d(lin3_size),
            nn.ReLU(),
            nn.Linear(lin3_size, output_size),
        )

        # loss function
        self.criterion = torch.nn.CrossEntropyLoss()

        # metric objects for calculating and averaging accuracy across batches;
        # the class count follows the size of the output layer
        self.train_acc = Accuracy(task="multiclass", num_classes=output_size)
        self.val_acc = Accuracy(task="multiclass", num_classes=output_size)
        self.test_acc = Accuracy(task="multiclass", num_classes=output_size)

        # for averaging loss across batches
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        # for tracking best so far validation accuracy
        self.val_acc_best = MaxMetric()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Performs a forward pass through the model `self.net`.

        Args:
            x: A batch of inputs; every dimension after the first is flattened
                into a single feature dimension before entering the network.

        Returns:
            A tensor of logits, one row per sample.
        """
        # `flatten` also handles non-contiguous and already-flat batches,
        # which a 4-way unpack followed by `.view` would reject.
        return self.net(x.flatten(start_dim=1))

    def on_train_start(self) -> None:
        """Lightning hook that is called when training begins."""
        # by default lightning executes validation step sanity checks before training starts,
        # so it's worth to make sure validation metrics don't store results from these checks
        self.val_loss.reset()
        self.val_acc.reset()
        self.val_acc_best.reset()

    def model_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Performs a single model step on a batch of data.

        Args:
            batch: A batch of data (a tuple) containing the input tensor of images and target
                labels.

        Returns:
            A tuple containing (in order):
                - The loss computed by `self.criterion`.
                - A tensor of predicted class indices (argmax over the logits).
                - A tensor of target labels.
        """
        x, y = batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Perform a single training step on a batch of data from the training set.

        Args:
            batch: A tuple containing the input tensor of images and target labels.
            batch_idx: The index of the current batch.

        Returns:
            The loss between model predictions and targets.
        """
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.train_loss(loss)
        self.train_acc(preds, targets)
        self.log(
            "train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True
        )
        self.log(
            "train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True
        )

        # return loss or backpropagation will fail
        return loss

    def on_train_epoch_end(self) -> None:
        """Lightning hook that is called when a training epoch ends."""
        pass

    def validation_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Perform a single validation step on a batch of data from the validation set.

        Args:
            batch: A tuple containing the input tensor of images and target labels.
            batch_idx: The index of the current batch.
        """
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.val_loss(loss)
        self.val_acc(preds, targets)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

    def on_validation_epoch_end(self) -> None:
        """Lightning hook that is called when a validation epoch ends."""
        acc = self.val_acc.compute()  # get current val acc
        self.val_acc_best(acc)  # update best so far val acc
        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
        # otherwise metric would be reset by lightning after each epoch
        self.log(
            "val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True
        )

    def test_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Performs a single test step on a batch of data from the test set.

        Args:
            batch: A tuple containing the input tensor of images and target labels.
            batch_idx: The index of the current batch.
        """
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.test_loss(loss)
        self.test_acc(preds, targets)
        self.log(
            "test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True
        )
        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)

    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends."""
        pass

    def setup(self, stage: str) -> None:
        """Called at the beginning of fit (train + validate), validate, test, or predict.

        This is a good place to build models dynamically or adjust something about them. This
        hook is called on every process when using DDP.

        Args:
            stage: One of "fit", "validate", "test", or "predict".
        """
        if self.hparams.compile and stage == "fit":
            self.net = torch.compile(self.net)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Configure optimizers and learning-rate schedulers.

        Returns:
            A dict with the optimizer under ``"optimizer"`` and, when a
            scheduler factory was given, its configuration under
            ``"lr_scheduler"``.
        """
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        if self.hparams.scheduler is None:
            return {"optimizer": optimizer}

        scheduler = self.hparams.scheduler(optimizer=optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val/loss",
                "interval": "epoch",
                "frequency": 1,
            },
        }

Now we initialize `MNISTLitModule` and train it with a `Trainer`, using the dataloaders and configuration provided by `MNISTDataModule`.

In [None]:
# Minimal trainer: hardware selection is left to Lightning and everything is
# logged to a single TensorBoard run folder named "tmp" under LOG_DIR.
logger = TensorBoardLogger(name="tmp", save_dir=LOG_DIR)

trainer = L.Trainer(
    logger=logger,
    num_nodes=1,
    devices="auto",
    strategy="auto",
    accelerator="auto",
)

In [None]:
# Seed Python, NumPy and torch (and dataloader workers) for reproducibility.
L.seed_everything(seed=42, workers=True)

# Log every run both as CSV files and as TensorBoard events.
loggers = [
    CSVLogger(LOG_DIR, name="csv"),
    TensorBoardLogger(LOG_DIR, name="tensorboard", log_graph=True),
]

callbacks = [
    ModelCheckpoint(
        dirpath=MODEL_DIR,
        # The logged metric names contain "/" ("val/loss", "val/acc"). With
        # automatic name insertion that slash would end up in the file name
        # and split it into sub-directories, so the labels are written out
        # explicitly and automatic insertion is disabled.
        filename="epoch={epoch}_val_loss={val/loss:.2f}_val_acc={val/acc:.2f}",
        auto_insert_metric_name=False,
        save_top_k=10,
        monitor="val/loss",
        mode="min",
    ),
    EarlyStopping(
        monitor="val/loss", min_delta=2e-4, patience=8, verbose=False, mode="min"
    ),
    LearningRateMonitor(logging_interval="step"),
]

trainer = L.Trainer(
    fast_dev_run=False,
    accelerator="auto",
    strategy="auto",
    devices="auto",
    num_nodes=1,
    logger=loggers,
    callbacks=callbacks,
    max_epochs=10,
    min_epochs=5,
    overfit_batches=0.0,
    log_every_n_steps=10,
)

# Extra augmentation handed to the data module.
# NOTE(review): flips and 45-degree rotations can change a digit's identity
# (e.g. 6 vs 9) - confirm these are intended for MNIST.
transforms = [
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomRotation(degrees=45),
]

datamodule = MNISTDataModule(MNIST_DIR, batch_size=1024, transforms=transforms)
model = MNISTLitModule(
    optimizer=torch.optim.Adam,
    # NOTE(review): step_size=1 with gamma=0.1 divides the learning rate by 10
    # after every epoch - confirm this aggressive decay is intended.
    scheduler=partial(torch.optim.lr_scheduler.StepLR, step_size=1, gamma=0.1),
    compile=False,
)
trainer.fit(model, datamodule)

In [None]:
# Load the TensorBoard notebook extension and open it inline, pointed at the
# directory the loggers above write to.
%load_ext tensorboard
%tensorboard --logdir={LOG_DIR}